| |
|
|
| |
| import torch, cv2, os |
| import numpy as np |
| from PIL import Image |
| from ultralytics import YOLO |
| from transformers import pipeline, DPTFeatureExtractor, DPTForDepthEstimation |
| from TTS.api import TTS |
| from huggingface_hub import login |
|
|
| |
# Authenticate with the Hugging Face Hub using a token from the environment.
# Raises KeyError if HUGGING_FACE_HUB_TOKEN is unset — fail fast by design.
login(token=os.environ["HUGGING_FACE_HUB_TOKEN"])


# Run on GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


# Object detection (YOLOv9-C weights) and monocular depth estimation
# (Intel DPT-Large, eval mode — inference only).
yolo_model = YOLO("yolov9c.pt")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device).eval()
depth_feat = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")


# Speech-to-text for spoken questions. transformers pipelines take a CUDA
# device index (0) or -1 for CPU, unlike the torch device string above.
whisper_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=0 if torch.cuda.is_available() else -1
)


# Vision-language model used to answer questions about the scene.
# NOTE(review): process_inputs calls this with a text-only prompt — confirm
# the "image-to-text" task actually accepts text-only input for this model.
gemma_pipe = pipeline(
    "image-to-text",
    model="google/gemma-3-4b-it",
    device=0 if torch.cuda.is_available() else -1,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
)


# Text-to-speech used to voice the generated answers.
tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC")
|
|
| |
| |
| |
|
|
class VisualQAState:
    """Mutable per-session state for the visual question-answering loop.

    Holds the image currently being discussed, the textual scene
    description derived from it, and the running chat transcript as a
    list of ``{"role": ..., "content": ...}`` dicts.
    """

    def __init__(self):
        # FIX: the previous annotation claimed a non-optional Image while
        # defaulting to None; annotations are also kept as strings so they
        # are never evaluated at runtime (PIL is only needed once a real
        # image is supplied).
        self.current_image: "Image.Image | None" = None  # image under discussion
        self.visual_context: str = ""  # scene description built from the image
        self.message_history: list[dict] = []  # ordered user/assistant turns

    def reset(self, image: "Image.Image", visual_context: str) -> None:
        """Start a fresh conversation about *image*, clearing the transcript."""
        self.current_image = image
        self.visual_context = visual_context
        self.message_history = []

    def add_question(self, question: str) -> None:
        """Append a user turn to the transcript."""
        self.message_history.append({"role": "user", "content": question})

    def add_answer(self, answer: str) -> None:
        """Append an assistant turn to the transcript."""
        self.message_history.append({"role": "assistant", "content": answer})
|
|
| |
| |
| |
|
|
def generate_visual_context(pil_image: Image.Image) -> str:
    """Describe the objects detected in *pil_image* as one English sentence.

    Runs YOLO object detection and DPT depth estimation, then summarizes
    each detection's label, confidence, average depth inside its box, and
    horizontal position (left / center / right third of the frame).

    Returns a sentence of the form
    "In the image, a <label> (<conf> confidence) is at <depth> on the <pos>, ..."
    or an explicit no-detections message when YOLO finds nothing.
    """
    rgb_image = np.array(pil_image)
    # YOLO (via OpenCV) expects BGR channel order.
    cv2_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)

    yolo_results = yolo_model.predict(cv2_image)[0]
    boxes = yolo_results.boxes
    class_names = yolo_model.names

    # Dense depth map, resized to the input resolution so detection boxes
    # index it directly.
    depth_inputs = depth_feat(images=pil_image, return_tensors="pt").to(device)
    with torch.no_grad():
        depth_output = depth_model(**depth_inputs)
    depth_map = depth_output.predicted_depth.squeeze().cpu().numpy()
    height, width = rgb_image.shape[:2]
    depth_map_resized = cv2.resize(depth_map, (width, height))

    shared_visual_context = []
    for box in boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        label = class_names[int(box.cls[0])]
        conf = float(box.conf[0])

        # Average depth within the box; None when the crop is degenerate.
        depth_crop = depth_map_resized[y1:y2, x1:x2]
        avg_depth = float(depth_crop.mean()) if depth_crop.size > 0 else None

        # Horizontal thirds of the frame -> left / center / right.
        x_center = (x1 + x2) / 2
        if x_center < width / 3:
            pos = "left"
        elif x_center > 2 * width / 3:
            pos = "right"
        else:
            pos = "center"

        shared_visual_context.append({
            "label": label,
            "confidence": conf,
            "avg_depth": avg_depth,
            "position": pos
        })

    # Edge case: no detections — avoid emitting the malformed "In the image, ."
    if not shared_visual_context:
        return "In the image, no objects were detected."

    descriptions = []
    for obj in shared_visual_context:
        # BUG FIX: compare against None, not truthiness — an average depth of
        # exactly 0.0 is a valid measurement, not "unknown".
        d = f"{obj['avg_depth']:.1f} units" if obj["avg_depth"] is not None else "unknown"
        s = obj.get("position", "unknown")
        c = obj.get("confidence", 0.0)
        descriptions.append(f"a {obj['label']} ({c:.2f} confidence) is at {d} on the {s}")

    return "In the image, " + ", ".join(descriptions) + "."
|
|
| |
| |
| |
|
|
# Single module-level QA session; callers of process_inputs may pass it in
# to keep one running conversation across calls.
session = VisualQAState()
|
def process_inputs(
    session: "VisualQAState",
    image: "Image.Image | None" = None,
    question: str = "",
    audio_path: "str | None" = None,
    enable_tts: bool = True
):
    """Answer a typed and/or spoken question about an image.

    Parameters:
        session: conversation state to read and update.
        image: if given, starts a new conversation with a fresh scene context.
        question: typed question text (may be empty when audio is supplied).
        audio_path: optional path to an audio file to transcribe with Whisper.
        enable_tts: when True, synthesize the answer to "response.wav".

    Returns:
        (answer_text, output_audio_path) — the path is None when TTS is off.
    """
    # A new image starts a fresh conversation with a newly built scene context.
    # BUG FIX: test against None explicitly — relying on the truthiness of a
    # PIL Image for an Optional check is not a documented contract.
    if image is not None:
        visual_context = generate_visual_context(image)
        session.reset(image, visual_context)

    # Transcribe spoken input and merge it with any typed question.
    if audio_path is not None:
        audio_text = whisper_pipe(audio_path)["text"]
        # BUG FIX: strip() avoids a stray leading space when only audio
        # input was provided (question started out empty).
        question = f"{question} {audio_text}".strip()

    session.add_question(question)

    # Ground the model in the detection/depth summary built for this image.
    prompt = f"{session.visual_context}\n\nUser Question: {question}"

    gemma_output = gemma_pipe(prompt, max_new_tokens=200)
    answer = gemma_output[0]["generated_text"]

    session.add_answer(answer)

    # Optionally voice the answer to a wav file next to the script.
    output_audio_path = "response.wav"
    if enable_tts:
        tts.tts_to_file(text=answer, file_path=output_audio_path)
    else:
        output_audio_path = None

    return answer, output_audio_path
|
|