|
|
|
|
|
"""project_model.ipynb |
|
|
|
|
|
Automatically generated by Colab. |
|
|
|
|
|
Original file is located at |
|
|
https://colab.research.google.com/drive/1oopkA5yIlfizFuhXOPmTK7MUNh3Qasa3 |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch, cv2, os |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from ultralytics import YOLO |
|
|
from transformers import pipeline, DPTFeatureExtractor, DPTForDepthEstimation |
|
|
from TTS.api import TTS |
|
|
from huggingface_hub import login |
|
|
|
|
|
|
|
|
# Authenticate with the Hugging Face Hub. The token must be provided via the
# HUGGING_FACE_HUB_TOKEN environment variable (raises KeyError if unset).
login(token=os.environ["HUGGING_FACE_HUB_TOKEN"])

# Prefer GPU when available; the torch models below are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Object detection model (YOLOv9-c; weights downloaded on first use).
yolo_model = YOLO("yolov9c.pt")

# Monocular depth estimation (Intel DPT-Large), inference-only (.eval()).
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device).eval()

depth_feat = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")

# Speech-to-text. transformers pipelines take a CUDA device index (0) or -1 for CPU.
whisper_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=0 if torch.cuda.is_available() else -1
)

# Multimodal (image + text -> text) model used to answer questions about the image.
# bfloat16 reduces GPU memory use; fall back to float32 on CPU.
gemma_pipe = pipeline(
    "image-text-to-text",
    model="google/gemma-3-4b-it",
    device=0 if torch.cuda.is_available() else -1,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
)

# Text-to-speech model used to voice the generated answers.
tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VisualQAState:
    """
    Stores the current image context and chat history for follow-up questions.

    Attributes:
        current_image: the most recently uploaded image, or None before any
            upload has occurred.
        visual_context: natural-language description of ``current_image``.
        message_history: chat messages in the list-of-dicts format expected
            by the image-text-to-text pipeline.

    Annotations referring to PIL are quoted so they are not evaluated at
    class-creation time, and the initial-``None`` attribute is annotated as
    optional (the original annotated it as a plain ``Image.Image``).
    """

    def __init__(self):
        # No image yet: Optional annotation because None is the initial value.
        self.current_image: "Image.Image | None" = None
        self.visual_context: str = ""
        self.message_history: list = []

    def reset(self, image: "Image.Image", visual_context: str):
        """
        Called when a new image is uploaded.

        Replaces the stored image/context and starts a new message history
        whose first user turn carries both the image and its visual context.
        """
        self.current_image = image
        self.visual_context = visual_context
        self.message_history = [{
            "role": "user",
            "content": [
                {"type": "image", "image": self.current_image},
                {"type": "text", "text": self.visual_context}
            ]
        }]

    def add_question(self, question: str):
        """
        Adds a follow-up text message to the chat.
        """
        self.message_history.append({
            "role": "user",
            "content": [{"type": "text", "text": question}]
        })
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_visual_context(pil_image: Image.Image) -> str:
    """
    Processes the image to extract object labels, depth info, and locations.

    Runs YOLO object detection and DPT depth estimation, then summarizes each
    detected object's label, confidence, approximate depth, and horizontal
    position (left / center / right) as one natural-language sentence.

    Args:
        pil_image: Input image (PIL, RGB).

    Returns:
        A sentence describing the detected objects, or an explicit
        "no objects" message when nothing is detected.
    """
    rgb_image = np.array(pil_image)
    # OpenCV / YOLO input expects BGR channel order.
    cv2_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)

    # Object detection.
    yolo_results = yolo_model.predict(cv2_image)[0]
    boxes = yolo_results.boxes
    class_names = yolo_model.names

    # Depth estimation; resize the predicted map back to the original image
    # resolution so detection-box coordinates index it directly.
    depth_inputs = depth_feat(images=pil_image, return_tensors="pt").to(device)
    with torch.no_grad():
        depth_output = depth_model(**depth_inputs)
    depth_map = depth_output.predicted_depth.squeeze().cpu().numpy()
    depth_map_resized = cv2.resize(depth_map, (rgb_image.shape[1], rgb_image.shape[0]))

    height, width = rgb_image.shape[0], rgb_image.shape[1]
    shared_visual_context = []
    for box in boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        # Clamp to image bounds: detector boxes can spill slightly outside,
        # which would yield empty/invalid depth crops.
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(width, x2), min(height, y2)
        label = class_names[int(box.cls[0])]
        conf = float(box.conf[0])

        depth_crop = depth_map_resized[y1:y2, x1:x2]
        avg_depth = float(depth_crop.mean()) if depth_crop.size > 0 else None

        # Horizontal thirds of the frame: left / center / right.
        x_center = (x1 + x2) / 2
        if x_center < width / 3:
            pos = "left"
        elif x_center > 2 * width / 3:
            pos = "right"
        else:
            pos = "center"

        shared_visual_context.append({
            "label": label,
            "confidence": conf,
            "avg_depth": avg_depth,
            "position": pos
        })

    # Without this guard the function returned the malformed "In the image, ."
    if not shared_visual_context:
        return "No objects were detected in the image."

    descriptions = []
    for obj in shared_visual_context:
        # Compare against None explicitly: a depth of 0.0 is valid data, but
        # the original truthiness test reported it as "unknown".
        d = f"{obj['avg_depth']:.1f} units" if obj["avg_depth"] is not None else "unknown"
        s = obj.get("position", "unknown")
        c = obj.get("confidence", 0.0)
        descriptions.append(f"a {obj['label']} ({c:.2f} confidence) is at {d} on the {s}")

    return "In the image, " + ", ".join(descriptions) + "."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def process_inputs(
    session: VisualQAState,
    image: Image.Image = None,
    question: str = "",
    audio_path: str = None,
    enable_tts: bool = True
):
    """
    Handles a new image upload or a follow-up question.

    Combines image context, audio transcription, and text input to generate a
    GEMMA-based answer. Optionally outputs audio using TTS.

    Args:
        session: Conversation state; reset when a new image arrives.
        image: New image to describe, or None for a follow-up question.
        question: Typed question text (may be empty if audio is supplied).
        audio_path: Optional path to an audio file to transcribe with Whisper.
        enable_tts: When True, synthesize the answer to "response.wav".

    Returns:
        (answer, output_audio_path) — output_audio_path is None when TTS is
        disabled.
    """
    # New image: rebuild the visual context and start a fresh conversation.
    # Explicit None check — PIL images should not be tested by truthiness.
    if image is not None:
        visual_context = generate_visual_context(image)
        session.reset(image, visual_context)

    # Transcribe spoken input and append it to the typed question.
    if audio_path:
        audio_text = whisper_pipe(audio_path)["text"]
        # strip() avoids a leading space when the typed question is empty.
        question = (question + " " + audio_text).strip()

    session.add_question(question)

    # Generate the answer; the pipeline returns the full chat, and the last
    # message holds the model's reply.
    gemma_output = gemma_pipe(text=session.message_history, max_new_tokens=200)
    answer = gemma_output[0]["generated_text"][-1]["content"]

    # Optionally voice the answer.
    output_audio_path = "response.wav"
    if enable_tts:
        tts.tts_to_file(text=answer, file_path=output_audio_path)
    else:
        output_audio_path = None

    return answer, output_audio_path