# -*- coding: utf-8 -*-
"""project_model.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1oopkA5yIlfizFuhXOPmTK7MUNh3Qasa3
"""

# project_module.py
# Multimodal visual question answering: object detection (YOLO), depth
# estimation (DPT), speech-to-text (Whisper), image+text QA (Gemma) and
# optional text-to-speech output.

# Import libraries for ML, CV, NLP, audio, and TTS
import os
from typing import Any, Dict, List, Optional

import cv2
import numpy as np
import torch
from PIL import Image
from ultralytics import YOLO
from transformers import pipeline, DPTFeatureExtractor, DPTForDepthEstimation
from TTS.api import TTS
from huggingface_hub import login

# Authenticate to Hugging Face only when a token is provided in the
# environment. Gated models such as Gemma need credentials. Skipping login
# when the variable is absent avoids a bare KeyError at import time and lets
# huggingface_hub use whatever credentials it can find otherwise
# (e.g. a cached login).
_hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Set device for computation (GPU if available)
device = "cuda" if torch.cuda.is_available() else "cpu"
# transformers pipelines take a device index: 0 = first GPU, -1 = CPU.
_pipeline_device = 0 if torch.cuda.is_available() else -1

# Load all models once at import time so every call reuses them.
yolo_model = YOLO("yolov9c.pt")  # YOLOv9 for object detection
depth_model = (
    DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device).eval()
)  # DPT (MiDaS family) for monocular depth
depth_feat = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")  # Preprocessing for the depth model

# Whisper for audio transcription
whisper_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=_pipeline_device,
)

# Gemma for image+text to text QA (bfloat16 on GPU, float32 on CPU)
gemma_pipe = pipeline(
    "image-text-to-text",
    model="google/gemma-3-4b-it",
    device=_pipeline_device,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
)

# Text-to-speech
tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC")


# -------------------------------
# Session Management Class
# -------------------------------
class VisualQAState:
    """Current image context plus chat history for follow-up questions.

    ``message_history`` uses the chat-message format passed to the
    ``image-text-to-text`` pipeline. It is a list of ``{"role": ..., "content": [...]}``
    dicts, where each content part is ``{"type": "image", "image": ...}`` or
    ``{"type": "text", "text": ...}``.

    Gemma's chat template requires user/assistant roles to alternate. For that
    reason, consecutive user inputs are merged into a single user turn, and
    model replies are recorded with :meth:`add_answer`.
    """

    def __init__(self):
        self.current_image: Optional[Image.Image] = None
        self.visual_context: str = ""
        self.message_history: List[Dict[str, Any]] = []

    def reset(self, image: Image.Image, visual_context: str):
        """Start a new conversation about ``image``.

        Called when a new image is uploaded. Discards any previous history and
        seeds it with one user turn holding the image and its generated
        textual context.
        """
        self.current_image = image
        self.visual_context = visual_context
        self.message_history = [{
            "role": "user",
            "content": [
                {"type": "image", "image": self.current_image},
                {"type": "text", "text": self.visual_context},
            ],
        }]

    def add_question(self, question: str):
        """Add a user question to the chat.

        If the last message is already a user turn (e.g. the image/context
        turn created by :meth:`reset`), the question is appended to that turn's
        content. Opening a second consecutive user message would break the
        alternating-roles requirement.
        """
        text_part = {"type": "text", "text": question}
        if self.message_history and self.message_history[-1]["role"] == "user":
            self.message_history[-1]["content"].append(text_part)
        else:
            self.message_history.append({"role": "user", "content": [text_part]})

    def add_answer(self, answer: str):
        """Record the model's reply so follow-up questions see it as context."""
        self.message_history.append({
            "role": "assistant",
            "content": [{"type": "text", "text": answer}],
        })


# -------------------------------
# Generate Context from Image
# -------------------------------
def _horizontal_position(x_center: float, width: int) -> str:
    """Classify an x coordinate as the left, center or right third of the image."""
    if x_center < width / 3:
        return "left"
    if x_center > 2 * width / 3:
        return "right"
    return "center"


def _describe_objects(objects: List[Dict[str, Any]]) -> str:
    """Render per-object facts (label, confidence, depth, position) as one sentence."""
    if not objects:
        # Avoid producing the malformed sentence "In the image, ."
        return "In the image, no objects were detected."
    descriptions = []
    for obj in objects:
        # Compare against None so a legitimate depth of 0.0 is still reported.
        d = f"{obj['avg_depth']:.1f} units" if obj["avg_depth"] is not None else "unknown"
        s = obj.get("position", "unknown")
        c = obj.get("confidence", 0.0)
        descriptions.append(f"a {obj['label']} ({c:.2f} confidence) is at {d} on the {s}")
    return "In the image, " + ", ".join(descriptions) + "."


def generate_visual_context(pil_image: Image.Image) -> str:
    """Build a natural-language description of the objects in an image.

    Runs YOLO object detection and DPT depth estimation. For each detected
    box it records the label, detection confidence, mean depth over the box,
    and which horizontal third of the image the box centre falls in. The
    result is rendered as one sentence for use in prompting.

    Args:
        pil_image: Input image in any PIL mode; it is converted to RGB.

    Returns:
        A sentence starting with "In the image, ".
    """
    # Normalise to 3-channel RGB. Grayscale ("L") or palette ("P") uploads
    # would otherwise yield a 2-D array that cv2.cvtColor rejects.
    pil_image = pil_image.convert("RGB")
    rgb_image = np.array(pil_image)
    height, width = rgb_image.shape[:2]
    # ultralytics treats numpy input as BGR (OpenCV convention).
    cv2_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)

    # Object detection using YOLO
    yolo_results = yolo_model.predict(cv2_image)[0]
    boxes = yolo_results.boxes
    class_names = yolo_model.names

    # Depth estimation using DPT
    depth_inputs = depth_feat(images=pil_image, return_tensors="pt").to(device)
    with torch.no_grad():
        depth_output = depth_model(**depth_inputs)
    depth_map = depth_output.predicted_depth.squeeze().cpu().numpy()
    # Resize the depth map to the image size so box pixel coordinates index it
    # directly. cv2.resize takes (width, height).
    # NOTE(review): DPT/MiDaS output is typically *relative inverse* depth
    # (larger = closer), so the "units" in the description are not metric.
    # Confirm how the LLM should interpret them.
    depth_map_resized = cv2.resize(depth_map, (width, height))

    # Extract contextual information for each object
    objects = []
    for box in boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        depth_crop = depth_map_resized[y1:y2, x1:x2]
        avg_depth = float(depth_crop.mean()) if depth_crop.size > 0 else None
        objects.append({
            "label": class_names[int(box.cls[0])],
            "confidence": float(box.conf[0]),
            "avg_depth": avg_depth,
            "position": _horizontal_position((x1 + x2) / 2, width),
        })

    return _describe_objects(objects)


# -------------------------------
# Main Multimodal Processing Function
# -------------------------------
def process_inputs(
    session: VisualQAState,
    image: Optional[Image.Image] = None,
    question: str = "",
    audio_path: Optional[str] = None,
    enable_tts: bool = True,
    output_audio_path: str = "response.wav",
):
    """Answer a new image upload and/or a follow-up question with Gemma.

    If ``image`` is given, the session is reset with a freshly generated
    visual context. Any audio clip is transcribed with Whisper and appended
    to the typed question. The full conversation is then sent to Gemma, and
    the answer is recorded in the session. Optionally the answer is spoken
    via TTS.

    Args:
        session: Conversation state; mutated in place.
        image: New image to discuss, or None to continue the current chat.
        question: Typed question text (may be empty).
        audio_path: Path to an audio clip to transcribe, or None.
        enable_tts: Whether to synthesize the answer to a WAV file.
        output_audio_path: Where to write the synthesized speech.

    Returns:
        ``(answer, audio_path)``. ``audio_path`` is None when TTS is disabled
        or the answer is empty.

    Raises:
        ValueError: If there is no pending user input to answer (no new image,
            question or audio, and the history does not end on a user turn).
    """
    # If a new image is provided, reset the session and build new context
    if image is not None:
        visual_context = generate_visual_context(image)
        session.reset(image, visual_context)

    # Combine typed text and transcribed audio without stray whitespace
    question = (question or "").strip()
    if audio_path:
        audio_text = whisper_pipe(audio_path)["text"].strip()
        question = f"{question} {audio_text}".strip()

    if question:
        session.add_question(question)

    history = session.message_history
    if not history or history[-1]["role"] != "user":
        raise ValueError("Nothing to answer: provide a new image, a question, or an audio clip.")

    # With chat-format input the pipeline returns the whole conversation; the
    # last message is the newly generated assistant turn.
    gemma_output = gemma_pipe(text=history, max_new_tokens=200)
    answer = gemma_output[0]["generated_text"][-1]["content"]
    session.add_answer(answer)

    # If TTS is enabled, synthesize the answer as speech
    if enable_tts and answer.strip():
        tts.tts_to_file(text=answer, file_path=output_audio_path)
        return answer, output_audio_path
    return answer, None