"""project_model.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1oopkA5yIlfizFuhXOPmTK7MUNh3Qasa3
"""
|
|
| |
|
|
| |
import os

import cv2
import numpy as np
import torch
from PIL import Image
from TTS.api import TTS
from huggingface_hub import login
from transformers import DPTFeatureExtractor, DPTForDepthEstimation, pipeline
from ultralytics import YOLO

# Authenticate with the Hugging Face Hub. Skip gracefully when no token is
# set so the script can still run against already-cached public models
# (os.environ["..."] would raise KeyError and abort the whole module).
_hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Prefer the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Object detection (YOLOv9) and monocular depth estimation (Intel DPT).
yolo_model = YOLO("yolov9c.pt")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device).eval()
depth_feat = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")

# Speech-to-text (Whisper) for transcribing spoken questions.
whisper_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=0 if torch.cuda.is_available() else -1,
)

# Multimodal chat model (Gemma 3) that answers questions about the image.
gemma_pipe = pipeline(
    "image-text-to-text",
    model="google/gemma-3-4b-it",
    device=0 if torch.cuda.is_available() else -1,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
)

# Text-to-speech for speaking the answers back to the user.
tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC")
|
|
| |
| |
| |
|
|
| class VisualQAState: |
| """ |
| Stores the current image context and chat history for follow-up questions. |
| """ |
| def __init__(self): |
| self.current_image: Image.Image = None |
| self.visual_context: str = "" |
| self.message_history = [] |
|
|
| def reset(self, image: Image.Image, visual_context: str): |
| """ |
| Called when a new image is uploaded. |
| Resets context and starts new message history. |
| """ |
| self.current_image = image |
| self.visual_context = visual_context |
| self.message_history = [{ |
| "role": "user", |
| "content": [ |
| {"type": "image", "image": self.current_image}, |
| {"type": "text", "text": self.visual_context} |
| ] |
| }] |
|
|
| def add_question(self, question: str): |
| """ |
| Adds a follow-up question only if the last message was from assistant. |
| Ensures alternating user/assistant messages. |
| """ |
| if not self.message_history or self.message_history[-1]["role"] == "assistant": |
| self.message_history.append({ |
| "role": "user", |
| "content": [{"type": "text", "text": question}] |
| }) |
|
|
| def add_answer(self, answer: str): |
| """ |
| Appends the assistant's response to the conversation history. |
| """ |
| self.message_history.append({ |
| "role": "assistant", |
| "content": [{"type": "text", "text": answer}] |
| }) |
|
|
| |
| |
| |
|
|
def generate_visual_context(pil_image: "Image.Image") -> str:
    """Build a natural-language scene description for *pil_image*.

    Runs YOLO object detection and DPT depth estimation, then summarizes
    each detection (label, confidence, mean depth, horizontal position)
    into one sentence used as the visual context for prompting.
    """
    rgb_image = np.array(pil_image)
    # OpenCV expects BGR channel order.
    cv2_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)

    # --- Object detection -------------------------------------------------
    yolo_results = yolo_model.predict(cv2_image)[0]
    boxes = yolo_results.boxes
    class_names = yolo_model.names

    # --- Depth estimation -------------------------------------------------
    depth_inputs = depth_feat(images=pil_image, return_tensors="pt").to(device)
    with torch.no_grad():
        depth_output = depth_model(**depth_inputs)
    depth_map = depth_output.predicted_depth.squeeze().cpu().numpy()
    # Resize the depth map to the image resolution so box coords index it.
    depth_map_resized = cv2.resize(depth_map, (rgb_image.shape[1], rgb_image.shape[0]))

    height, width = rgb_image.shape[:2]
    shared_visual_context = []
    for box in boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        # Clamp to the image: negative coords would silently wrap the slice.
        x1, y1 = max(x1, 0), max(y1, 0)
        label = class_names[int(box.cls[0])]
        conf = float(box.conf[0])

        # Mean depth over the detection crop (None for degenerate boxes).
        depth_crop = depth_map_resized[y1:y2, x1:x2]
        avg_depth = float(depth_crop.mean()) if depth_crop.size > 0 else None

        # Horizontal thirds: left / center / right.
        x_center = (x1 + x2) / 2
        if x_center < width / 3:
            pos = "left"
        elif x_center > 2 * width / 3:
            pos = "right"
        else:
            pos = "center"

        shared_visual_context.append({
            "label": label,
            "confidence": conf,
            "avg_depth": avg_depth,
            "position": pos,
        })

    # Avoid emitting the malformed sentence "In the image, ." when nothing
    # was detected.
    if not shared_visual_context:
        return "In the image, no objects were detected."

    descriptions = []
    for obj in shared_visual_context:
        # `is not None`: a legitimate 0.0 mean depth must not read "unknown".
        d = f"{obj['avg_depth']:.1f} units" if obj["avg_depth"] is not None else "unknown"
        s = obj.get("position", "unknown")
        c = obj.get("confidence", 0.0)
        descriptions.append(f"a {obj['label']} ({c:.2f} confidence) is at {d} on the {s}")

    return "In the image, " + ", ".join(descriptions) + "."
|
|
| |
| |
| |
|
|
| |
# Global session shared by all calls to process_inputs.
session = VisualQAState()
|
|
def process_inputs(
    session: VisualQAState,
    image: "Image.Image" = None,
    question: str = "",
    audio_path: str = None,
    enable_tts: bool = True,
):
    """Answer a question about the current image.

    Handles either a new image upload (which rebuilds the visual context
    and resets the session) or a follow-up question. A spoken question at
    *audio_path* is transcribed with Whisper and appended to the typed
    *question*; the combined conversation is sent to the Gemma pipeline.

    Returns:
        (answer, output_audio_path) — output_audio_path is "response.wav"
        when *enable_tts* is True, otherwise None.
    """
    # New image: derive its context and start a fresh conversation.
    # `is not None` avoids relying on the image object's truth value.
    if image is not None:
        visual_context = generate_visual_context(image)
        session.reset(image, visual_context)

    # Merge the transcription into the typed question; strip() drops the
    # stray leading space when the typed question is empty.
    if audio_path:
        audio_text = whisper_pipe(audio_path)["text"]
        question = f"{question} {audio_text}".strip()

    session.add_question(question)

    # Ask Gemma; the pipeline returns the full chat with the new reply last.
    gemma_output = gemma_pipe(text=session.message_history, max_new_tokens=200)
    answer = gemma_output[0]["generated_text"][-1]["content"]

    session.add_answer(answer)

    # Optionally speak the answer to a wav file.
    output_audio_path = None
    if enable_tts:
        output_audio_path = "response.wav"
        tts.tts_to_file(text=answer, file_path=output_audio_path)

    return answer, output_audio_path
|
|