# Source: MutimodalVisionAssistant / project_model.py (Hugging Face Space,
# commit fc25a0b). Page-navigation residue ("raw / history blame / 6.89 kB")
# removed so the file parses as Python.
# -*- coding: utf-8 -*-
"""project_model.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1oopkA5yIlfizFuhXOPmTK7MUNh3Qasa3
"""
# project_module.py
# Import libraries for ML, CV, NLP, audio, and TTS
import torch, cv2, os
import numpy as np
from PIL import Image
from ultralytics import YOLO
from transformers import pipeline, DPTFeatureExtractor, DPTForDepthEstimation
from TTS.api import TTS
from huggingface_hub import login
# Authenticate to Hugging Face (gated models below require it).
# NOTE: os.environ[...] raises KeyError if HUGGING_FACE_HUB_TOKEN is unset —
# the process fails fast at import time rather than later at download time.
login(token=os.environ["HUGGING_FACE_HUB_TOKEN"])
# Set device for computation (GPU if available)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load all models once at module import; they are shared module-level globals
# used by generate_visual_context() and process_inputs() below.
yolo_model = YOLO("yolov9c.pt") # YOLOv9 for object detection
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device).eval() # DPT (MiDaS-style) monocular depth
depth_feat = DPTFeatureExtractor.from_pretrained("Intel/dpt-large") # Feature extractor paired with the depth model
# Whisper for audio transcription; device index 0 = first GPU, -1 = CPU
whisper_pipe = pipeline(
"automatic-speech-recognition",
model="openai/whisper-small",
device=0 if torch.cuda.is_available() else -1
)
# GEMMA for image+text to text QA (chat-style message history input)
gemma_pipe = pipeline(
"image-text-to-text",
model="google/gemma-3-4b-it",
device=0 if torch.cuda.is_available() else -1,
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
)
# Text-to-speech used to voice GEMMA's answers (see process_inputs)
tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC")
# -------------------------------
# Session Management Class
# -------------------------------
class VisualQAState:
    """
    Per-conversation state: the active image, the scene description derived
    from it, and the running chat transcript used for follow-up questions.
    """

    def __init__(self):
        # Nothing loaded yet: no image, no context, empty transcript.
        self.current_image: Image.Image = None
        self.visual_context: str = ""
        self.message_history = []

    def reset(self, image: Image.Image, visual_context: str):
        """
        Begin a fresh conversation around a newly uploaded image.
        Replaces the transcript with a single opening user turn that carries
        both the image and its textual context.
        """
        self.current_image = image
        self.visual_context = visual_context
        opening_turn = {
            "role": "user",
            "content": [
                {"type": "image", "image": self.current_image},
                {"type": "text", "text": self.visual_context},
            ],
        }
        self.message_history = [opening_turn]

    def add_question(self, question: str):
        """
        Record a user question only when doing so keeps the transcript
        alternating: the history is empty, or the last turn was the
        assistant's. Otherwise the question is silently dropped.
        """
        history = self.message_history
        last_was_assistant = bool(history) and history[-1]["role"] == "assistant"
        if last_was_assistant or not history:
            history.append(
                {"role": "user", "content": [{"type": "text", "text": question}]}
            )

    def add_answer(self, answer: str):
        """
        Append the assistant's reply, restoring the user/assistant alternation.
        """
        self.message_history.append(
            {"role": "assistant", "content": [{"type": "text", "text": answer}]}
        )
# -------------------------------
# Generate Context from Image
# -------------------------------
def generate_visual_context(pil_image: Image.Image) -> str:
    """
    Build a natural-language scene description from an image.

    Runs YOLO object detection and DPT depth estimation, then summarizes each
    detected object's label, confidence, approximate depth, and horizontal
    position (left / center / right).

    Args:
        pil_image: Input image (assumed RGB mode — TODO confirm at call sites).

    Returns:
        One English sentence describing the detected objects, or a fixed
        fallback sentence when nothing is detected.
    """
    # Convert once up front; YOLO via OpenCV expects BGR channel order.
    rgb_image = np.array(pil_image)
    cv2_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
    height, width = rgb_image.shape[:2]
    # Object detection using YOLO
    yolo_results = yolo_model.predict(cv2_image)[0]
    boxes = yolo_results.boxes
    class_names = yolo_model.names
    # Depth estimation; resize the predicted map back to the original
    # resolution so YOLO box coordinates index it directly.
    depth_inputs = depth_feat(images=pil_image, return_tensors="pt").to(device)
    with torch.no_grad():
        depth_output = depth_model(**depth_inputs)
    depth_map = depth_output.predicted_depth.squeeze().cpu().numpy()
    depth_map_resized = cv2.resize(depth_map, (width, height))
    # Extract contextual information for each detected object
    shared_visual_context = []
    for box in boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        label = class_names[int(box.cls[0])]
        conf = float(box.conf[0])
        # Average depth over the box; None when the crop is empty
        # (degenerate box after int truncation).
        depth_crop = depth_map_resized[y1:y2, x1:x2]
        avg_depth = float(depth_crop.mean()) if depth_crop.size > 0 else None
        # Horizontal thirds of the frame: left / center / right.
        x_center = (x1 + x2) / 2
        pos = "left" if x_center < width / 3 else "right" if x_center > 2 * width / 3 else "center"
        shared_visual_context.append({
            "label": label,
            "confidence": conf,
            "avg_depth": avg_depth,
            "position": pos
        })
    # BUGFIX: previously an empty detection set produced the malformed
    # sentence "In the image, ." — return a well-formed fallback instead.
    if not shared_visual_context:
        return "In the image, no objects were confidently detected."
    # Convert context to a readable sentence
    descriptions = []
    for obj in shared_visual_context:
        # BUGFIX: compare against None — a legitimate 0.0 average depth is
        # falsy and was previously reported as "unknown".
        d = f"{obj['avg_depth']:.1f} units" if obj["avg_depth"] is not None else "unknown"
        s = obj.get("position", "unknown")
        c = obj.get("confidence", 0.0)
        descriptions.append(f"a {obj['label']} ({c:.2f} confidence) is at {d} on the {s}")
    return "In the image, " + ", ".join(descriptions) + "."
# -------------------------------
# Main Multimodal Processing Function
# -------------------------------
# Module-level session so image context and chat history persist across
# follow-up calls; callers may also pass their own VisualQAState explicitly.
session = VisualQAState()
def process_inputs(
    session: VisualQAState,
    image: Image.Image = None,
    question: str = "",
    audio_path: str = None,
    enable_tts: bool = True
):
    """
    Answer a (possibly spoken) question about the current image.

    Handles a new image upload or a follow-up question: rebuilds the visual
    context when a new image arrives, transcribes any audio clip, queries
    GEMMA with the full conversation history, and optionally synthesizes the
    answer as speech.

    Args:
        session: Conversation state carried across calls.
        image: New image to analyze; None keeps the existing context.
        question: Typed question text (may be empty when audio is given).
        audio_path: Optional path to an audio clip to transcribe with Whisper.
        enable_tts: When True, also render the answer to "response.wav".

    Returns:
        Tuple of (answer text, path to the synthesized audio file, or None
        when TTS is disabled).
    """
    # BUGFIX: explicit None check — don't rely on the image object's
    # truthiness to decide whether a new image was supplied.
    if image is not None:
        visual_context = generate_visual_context(image)
        session.reset(image, visual_context)
    # If the user gave an audio clip, transcribe it and merge with the text.
    if audio_path:
        audio_text = whisper_pipe(audio_path)["text"]
        # BUGFIX: strip() avoids a stray leading space when the typed
        # question is empty and only audio was provided.
        question = f"{question} {audio_text}".strip()
    # Append the question (VisualQAState drops it unless turns alternate).
    session.add_question(question)
    # Generate a response using GEMMA with the full conversation history.
    gemma_output = gemma_pipe(text=session.message_history, max_new_tokens=200)
    answer = gemma_output[0]["generated_text"][-1]["content"]
    # Record the assistant turn to keep the transcript alternating.
    session.add_answer(answer)
    # Optionally synthesize the answer as speech.
    output_audio_path = None
    if enable_tts:
        output_audio_path = "response.wav"
        tts.tts_to_file(text=answer, file_path=output_audio_path)
    return answer, output_audio_path