MutimodalVisionAssistant / project_model.py
saa231's picture
Update project_model.py
9a71a48 verified
raw
history blame
6.2 kB
# -*- coding: utf-8 -*-
"""project_model.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1oopkA5yIlfizFuhXOPmTK7MUNh3Qasa3
"""
# project_module.py
# Import libraries for ML, CV, NLP, audio, and TTS
import torch, cv2, os
import numpy as np
from PIL import Image
from ultralytics import YOLO
from transformers import pipeline, DPTFeatureExtractor, DPTForDepthEstimation
from TTS.api import TTS
from huggingface_hub import login
# Authenticate to Hugging Face. huggingface_hub also honours credentials cached
# by `huggingface-cli login` and the HF_TOKEN environment variable, so only log
# in explicitly when HUGGING_FACE_HUB_TOKEN is set instead of crashing with a
# bare KeyError at import time. (google/gemma-3-4b-it is a gated model, so some
# credential must still be available for it to download.)
_hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Set device for computation (GPU if available)
device = "cuda" if torch.cuda.is_available() else "cpu"
# transformers pipelines take a device index: 0 = first GPU, -1 = CPU
_pipeline_device = 0 if device == "cuda" else -1

# Load all models
yolo_model = YOLO("yolov9c.pt")  # YOLOv9 for object detection
# DPT-Large (MiDaS) monocular depth estimator, put in inference mode
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device).eval()
depth_feat = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")  # preprocessing for the depth model

# Whisper for audio transcription
whisper_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=_pipeline_device,
)

# Gemma 3 for image+text -> text question answering.
# bfloat16 is used on GPU to reduce memory; CPU inference stays in float32.
gemma_pipe = pipeline(
    "image-text-to-text",
    model="google/gemma-3-4b-it",
    device=_pipeline_device,
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
)

# Text-to-speech (Coqui TTS, Tacotron2-DDC LJSpeech voice)
tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC")
# -------------------------------
# Session Management Class
# -------------------------------
class VisualQAState:
    """
    Conversation state for one visual question-answering session.

    Keeps the most recently uploaded image, the text description generated
    for it, and the chat transcript (a list of role/content message dicts)
    that is passed to the Gemma pipeline, so follow-up questions retain the
    image context.
    """

    def __init__(self):
        # No image yet: empty context and an empty transcript.
        self.current_image: Image.Image = None
        self.visual_context: str = ""
        self.message_history = []

    def reset(self, image: Image.Image, visual_context: str):
        """
        Begin a new conversation for a freshly uploaded image.

        Replaces the transcript with a single user turn that carries both the
        image and its generated textual context.
        """
        self.current_image = image
        self.visual_context = visual_context
        opening_turn = {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": visual_context},
            ],
        }
        self.message_history = [opening_turn]

    def add_question(self, question: str):
        """
        Append *question* to the transcript as a text-only user turn.
        """
        text_part = {"type": "text", "text": question}
        self.message_history.append({"role": "user", "content": [text_part]})
# -------------------------------
# Generate Context from Image
# -------------------------------
def _estimate_depth_map(pil_image: Image.Image, width: int, height: int) -> np.ndarray:
    """
    Run the DPT depth model on *pil_image* and resize its output to the image's
    pixel grid so it can be indexed with detection box coordinates.

    Assumes a single image, so squeezing the batch dimension leaves a 2-D map.
    """
    depth_inputs = depth_feat(images=pil_image, return_tensors="pt").to(device)
    with torch.no_grad():
        depth_output = depth_model(**depth_inputs)
    depth_map = depth_output.predicted_depth.squeeze().cpu().numpy()
    # cv2.resize takes (width, height), the reverse of numpy's (rows, cols).
    return cv2.resize(depth_map, (width, height))


def _horizontal_position(x_center: float, width: int) -> str:
    """Classify *x_center* as the left, center, or right third of an image *width* pixels wide."""
    if x_center < width / 3:
        return "left"
    if x_center > 2 * width / 3:
        return "right"
    return "center"


def _describe_objects(objects: list) -> str:
    """
    Turn per-object records (keys: label, confidence, avg_depth, position)
    into a single English sentence used as prompt context.
    """
    if not objects:
        # Avoid emitting the malformed sentence "In the image, ."
        return "In the image, no objects were detected."
    descriptions = []
    for obj in objects:
        # Compare against None explicitly: an average depth of 0.0 is a real
        # measurement and must not be reported as "unknown".
        depth = f"{obj['avg_depth']:.1f} units" if obj["avg_depth"] is not None else "unknown"
        descriptions.append(
            f"a {obj['label']} ({obj['confidence']:.2f} confidence) is at {depth} on the {obj['position']}"
        )
    return "In the image, " + ", ".join(descriptions) + "."


def generate_visual_context(pil_image: Image.Image) -> str:
    """
    Describe the objects in *pil_image* as natural language for prompting.

    Detects objects with YOLO, then for each detection reports its label,
    detector confidence, the mean depth-model value inside its bounding box,
    and whether it lies in the left, center, or right third of the image.

    NOTE(review): DPT/MiDaS predicts *relative* depth (reportedly inverse
    depth, i.e. larger = closer), so the "units" in the output are not a
    metric distance. Confirm before relying on them.
    """
    # Normalise to 3-channel RGB first. Grayscale ("L"), palette ("P") and
    # RGBA uploads otherwise yield 1- or 4-channel arrays that the RGB->BGR
    # conversion and depth preprocessing are not built for.
    rgb_pil = pil_image.convert("RGB")
    rgb_image = np.array(rgb_pil)
    height, width = rgb_image.shape[:2]
    # Ultralytics treats numpy inputs as BGR (OpenCV order).
    cv2_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)

    # Object detection using YOLO; one result per input image.
    yolo_results = yolo_model.predict(cv2_image)[0]
    class_names = yolo_model.names

    depth_map = _estimate_depth_map(rgb_pil, width, height)

    objects = []
    for box in yolo_results.boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        depth_crop = depth_map[y1:y2, x1:x2]
        objects.append({
            "label": class_names[int(box.cls[0])],
            "confidence": float(box.conf[0]),
            "avg_depth": float(depth_crop.mean()) if depth_crop.size > 0 else None,
            "position": _horizontal_position((x1 + x2) / 2, width),
        })
    return _describe_objects(objects)
# -------------------------------
# Main Multimodal Processing Function
# -------------------------------
def process_inputs(
    session: VisualQAState,
    image: Image.Image = None,
    question: str = "",
    audio_path: str = None,
    enable_tts: bool = True
):
    """
    Answer one user turn, optionally starting a new image conversation.

    Args:
        session: Conversation state. It is mutated in place: the user turn and
            the assistant's answer are both recorded in its message history.
        image: A newly uploaded image. When given, the session is reset and a
            fresh visual context is generated from it.
        question: Typed question text (may be empty).
        audio_path: Path to a spoken question. Its Whisper transcription is
            appended to *question*.
        enable_tts: When True, also synthesize the answer to "response.wav".

    Returns:
        ``(answer, output_audio_path)``. ``output_audio_path`` is None when TTS
        is disabled or the answer is empty.
    """
    # A new image starts a new conversation around its generated context.
    if image is not None:
        session.reset(image, generate_visual_context(image))

    # Merge typed and transcribed text without stray leading or double spaces.
    parts = [(question or "").strip()]
    if audio_path:
        parts.append(whisper_pipe(audio_path)["text"].strip())
    full_question = " ".join(part for part in parts if part)

    history = session.message_history
    if history and history[-1]["role"] == "user":
        # Gemma's chat template rejects conversations whose user/assistant roles
        # do not alternate. Fold the question into the pending user turn (e.g.
        # the image + context turn created by reset) rather than appending a
        # second consecutive user turn.
        if full_question:
            history[-1]["content"].append({"type": "text", "text": full_question})
    else:
        session.add_question(full_question)

    # Generate a response using Gemma over the full conversation history.
    gemma_output = gemma_pipe(text=history, max_new_tokens=200)
    answer = gemma_output[0]["generated_text"][-1]["content"]

    # Record the reply so follow-ups see it and roles keep alternating.
    history.append({"role": "assistant", "content": [{"type": "text", "text": answer}]})

    # Optionally synthesize the answer as speech; skip empty answers.
    output_audio_path = None
    if enable_tts and answer.strip():
        output_audio_path = "response.wav"
        tts.tts_to_file(text=answer, file_path=output_audio_path)
    return answer, output_audio_path