File size: 6,195 Bytes
dd0f882
 
 
 
 
 
 
 
 
 
 
9a71a48
 
dd0f882
 
 
 
 
1bad0b9
 
9a71a48
9c1199b
1bad0b9
9a71a48
 
dd0f882
9a71a48
 
 
 
1bad0b9
9a71a48
 
 
 
 
 
 
 
dd0f882
 
 
 
 
 
 
9a71a48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd0f882
 
9a71a48
dd0f882
 
 
 
9a71a48
dd0f882
 
 
 
 
 
9a71a48
dd0f882
 
 
 
 
9a71a48
 
dd0f882
 
9a71a48
 
dd0f882
 
9a71a48
dd0f882
 
 
 
 
 
 
9a71a48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd0f882
 
9a71a48
dd0f882
9a71a48
 
 
 
dd0f882
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# -*- coding: utf-8 -*-
"""project_model.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1oopkA5yIlfizFuhXOPmTK7MUNh3Qasa3
"""

# project_module.py

# Import libraries for ML, CV, NLP, audio, and TTS
import torch, cv2, os
import numpy as np
from PIL import Image
from ultralytics import YOLO
from transformers import pipeline, DPTFeatureExtractor, DPTForDepthEstimation
from TTS.api import TTS
from huggingface_hub import login

# Authenticate to Hugging Face using environment token.
# NOTE: runs at import time and raises KeyError if HUGGING_FACE_HUB_TOKEN
# is not set in the environment.
login(token=os.environ["HUGGING_FACE_HUB_TOKEN"])

# Set device for computation (GPU if available)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load all models. NOTE: all of the loads below happen as an import side
# effect; they presumably download weights on first use — confirm caching
# behavior before deploying.
yolo_model = YOLO("yolov9c.pt")  # YOLOv9 for object detection
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device).eval()  # MiDaS for depth
depth_feat = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")  # Feature extractor for depth model

# Whisper for audio transcription.
# transformers pipelines take an integer GPU index (0) or -1 for CPU.
whisper_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=0 if torch.cuda.is_available() else -1
)

# GEMMA for image+text to text QA (chat-style multimodal pipeline).
# bfloat16 is used only when CUDA is available; CPU falls back to float32.
gemma_pipe = pipeline(
    "image-text-to-text",
    model="google/gemma-3-4b-it",
    device=0 if torch.cuda.is_available() else -1,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
)

# Text-to-speech (single-speaker English Tacotron2 voice)
tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC")

# -------------------------------
# Session Management Class
# -------------------------------

class VisualQAState:
    """
    Per-session conversation state for visual question answering.

    Holds the current image, its generated visual-context description, and
    the chat history (in the transformers chat-message format) that is sent
    to the multimodal model on every turn.
    """

    def __init__(self) -> None:
        # Annotations are stringified so PIL types are not evaluated at
        # class-creation time, and the attribute is correctly marked as
        # optional: it stays None until reset() is called.
        self.current_image: "Image.Image | None" = None
        self.visual_context: str = ""
        # List of chat messages: {"role": ..., "content": [parts...]}
        self.message_history: list = []

    def reset(self, image: "Image.Image", visual_context: str) -> None:
        """
        Start a new conversation around a freshly uploaded image.

        Discards any previous history and seeds it with a single user turn
        containing both the image and its textual description.

        Args:
            image: The newly uploaded image.
            visual_context: Natural-language description of the image
                (e.g. from generate_visual_context).
        """
        self.current_image = image
        self.visual_context = visual_context
        self.message_history = [{
            "role": "user",
            "content": [
                {"type": "image", "image": self.current_image},
                {"type": "text", "text": self.visual_context}
            ]
        }]

    def add_question(self, question: str) -> None:
        """Append a follow-up text-only user turn to the chat history."""
        self.message_history.append({
            "role": "user",
            "content": [{"type": "text", "text": question}]
        })

# -------------------------------
# Generate Context from Image
# -------------------------------

def generate_visual_context(pil_image: Image.Image) -> str:
    """
    Build a natural-language description of an image for use in prompting.

    Runs YOLO object detection and DPT depth estimation (both loaded at
    module level), then describes each detected object with its label,
    detection confidence, average depth, and horizontal position.

    Args:
        pil_image: Input image (assumed RGB — TODO confirm callers never
            pass grayscale/RGBA).

    Returns:
        A sentence such as "In the image, a dog (0.87 confidence) is at
        12.3 units on the left.", or an explicit "no objects" sentence
        when nothing is detected.
    """
    # Convert PIL (RGB) -> numpy, and to BGR for OpenCV/YOLO.
    rgb_image = np.array(pil_image)
    img_h, img_w = rgb_image.shape[0], rgb_image.shape[1]
    cv2_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)

    # Object detection; predict() returns one Results per input image.
    yolo_results = yolo_model.predict(cv2_image)[0]
    boxes = yolo_results.boxes
    class_names = yolo_model.names

    # Depth estimation; resize the predicted map back to the input
    # resolution so box coordinates index it directly.
    depth_inputs = depth_feat(images=pil_image, return_tensors="pt").to(device)
    with torch.no_grad():
        depth_output = depth_model(**depth_inputs)
    depth_map = depth_output.predicted_depth.squeeze().cpu().numpy()
    depth_map_resized = cv2.resize(depth_map, (img_w, img_h))

    shared_visual_context = []
    for box in boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        # Clamp to image bounds: detector boxes may extend slightly past
        # the frame, and negative indices would silently slice wrongly.
        x1, y1 = max(x1, 0), max(y1, 0)
        x2, y2 = min(x2, img_w), min(y2, img_h)
        label = class_names[int(box.cls[0])]
        conf = float(box.conf[0])

        # Average depth over the (clamped) box; None for degenerate crops.
        depth_crop = depth_map_resized[y1:y2, x1:x2]
        avg_depth = float(depth_crop.mean()) if depth_crop.size > 0 else None

        # Horizontal thirds of the frame: left / center / right.
        x_center = (x1 + x2) / 2
        if x_center < img_w / 3:
            pos = "left"
        elif x_center > 2 * img_w / 3:
            pos = "right"
        else:
            pos = "center"

        shared_visual_context.append({
            "label": label,
            "confidence": conf,
            "avg_depth": avg_depth,
            "position": pos
        })

    # Robustness: avoid the malformed sentence "In the image, ." when the
    # detector finds nothing.
    if not shared_visual_context:
        return "In the image, no objects were detected."

    descriptions = []
    for obj in shared_visual_context:
        # BUG FIX: compare against None, not truthiness — an average depth
        # of exactly 0.0 was previously reported as "unknown".
        if obj["avg_depth"] is not None:
            d = f"{obj['avg_depth']:.1f} units"
        else:
            d = "unknown"
        s = obj.get("position", "unknown")
        c = obj.get("confidence", 0.0)
        descriptions.append(f"a {obj['label']} ({c:.2f} confidence) is at {d} on the {s}")

    return "In the image, " + ", ".join(descriptions) + "."

# -------------------------------
# Main Multimodal Processing Function
# -------------------------------

def process_inputs(
    session: VisualQAState,
    image: "Image.Image | None" = None,
    question: str = "",
    audio_path: "str | None" = None,
    enable_tts: bool = True
):
    """
    Handle one turn of the multimodal QA loop.

    Combines an optional new image (which resets the session context), an
    optional audio clip (transcribed and merged into the question), and the
    text question itself; asks GEMMA with the full conversation history and
    optionally synthesizes the answer as speech.

    Args:
        session: Conversation state to read and mutate.
        image: New image to start a fresh session around, or None to
            continue the current one.
        question: Typed question text (may be empty if audio is given).
        audio_path: Path to an audio clip to transcribe via Whisper.
        enable_tts: When True, write the spoken answer to "response.wav".

    Returns:
        (answer, output_audio_path) where output_audio_path is None when
        TTS is disabled.
    """
    # BUG FIX: explicit None check instead of truthiness — boolean
    # evaluation of image-like objects is unreliable (numpy arrays raise
    # on bool(), and some image wrappers define surprising __bool__).
    if image is not None:
        visual_context = generate_visual_context(image)
        session.reset(image, visual_context)

    # Transcribe optional audio and fold it into the question text.
    if audio_path:
        audio_text = whisper_pipe(audio_path)["text"]
        # Join only non-empty parts so an audio-only turn does not start
        # with a stray leading space.
        question = " ".join(part for part in (question, audio_text) if part).strip()

    # Record the user's question in the conversation history.
    session.add_question(question)

    # Ask GEMMA with the full history; the pipeline returns the whole
    # chat, and the final message holds the model's answer.
    gemma_output = gemma_pipe(text=session.message_history, max_new_tokens=200)
    answer = gemma_output[0]["generated_text"][-1]["content"]

    # Optionally synthesize the answer as speech.
    output_audio_path = None
    if enable_tts:
        output_audio_path = "response.wav"
        tts.tts_to_file(text=answer, file_path=output_audio_path)

    return answer, output_audio_path