MultimodalVisionAssistant / project_model.py
saa231's picture
Update project_model.py
00aadd5 verified
raw
history blame
4.62 kB
# project_model.py
# Import libraries for ML, CV, NLP, audio, and TTS
import torch, cv2, os
import numpy as np
from PIL import Image
from ultralytics import YOLO
from transformers import pipeline, DPTFeatureExtractor, DPTForDepthEstimation
from TTS.api import TTS
from huggingface_hub import login
# Authenticate to Hugging Face using environment token
# (raises KeyError if HUGGING_FACE_HUB_TOKEN is unset — fails fast on misconfiguration).
login(token=os.environ["HUGGING_FACE_HUB_TOKEN"])

# Set device for computation (GPU if available)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load all models once at import time so every request reuses them.
yolo_model = YOLO("yolov9c.pt") # YOLOv9 for object detection
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device).eval()
depth_feat = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")

# Whisper for audio transcription
# (pipeline device convention: 0 = first CUDA device, -1 = CPU)
whisper_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=0 if torch.cuda.is_available() else -1
)

# GEMMA for image+text to text QA
# NOTE(review): gemma-3-4b-it is a multimodal chat model; the "image-to-text"
# pipeline task may not be the correct tag for it — confirm against the
# transformers pipeline documentation (see also process_inputs, which passes
# this pipeline a plain text prompt).
gemma_pipe = pipeline(
    "image-to-text",
    model="google/gemma-3-4b-it",
    device=0 if torch.cuda.is_available() else -1,
    # bf16 only makes sense on GPU; fall back to fp32 on CPU.
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
)

# Text-to-speech
tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC")
# -------------------------------
# Session Management Class
# -------------------------------
class VisualQAState:
    """Per-session state for multimodal visual question answering.

    Tracks the image currently under discussion, a textual description of
    its contents, and the running user/assistant chat history.
    """

    def __init__(self):
        # FIX: annotate as optional — no image is loaded until reset() runs.
        # Annotations are quoted so they are not evaluated at runtime.
        self.current_image: "Image.Image | None" = None
        self.visual_context: str = ""
        self.message_history: "list[dict[str, str]]" = []

    def reset(self, image: "Image.Image", visual_context: str) -> None:
        """Start a fresh conversation around *image* and its description."""
        self.current_image = image
        self.visual_context = visual_context
        self.message_history = []

    def add_question(self, question: str) -> None:
        """Append a user turn to the chat history."""
        self.message_history.append({"role": "user", "content": question})

    def add_answer(self, answer: str) -> None:
        """Append an assistant turn to the chat history."""
        self.message_history.append({"role": "assistant", "content": answer})
# -------------------------------
# Generate Context from Image
# -------------------------------
def generate_visual_context(pil_image: Image.Image) -> str:
    """Build a one-sentence textual summary of the objects in *pil_image*.

    Runs YOLO object detection and DPT monocular depth estimation, then
    describes each detected object with its label, detection confidence,
    average depth over its bounding box, and rough horizontal position
    (left / center / right third of the frame).
    """
    rgb_image = np.array(pil_image)
    # OpenCV (and YOLO's cv2 path) expects BGR channel order.
    cv2_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
    yolo_results = yolo_model.predict(cv2_image)[0]
    boxes = yolo_results.boxes
    class_names = yolo_model.names

    # Dense depth map, resized to the original resolution so that box
    # coordinates can index it directly.
    depth_inputs = depth_feat(images=pil_image, return_tensors="pt").to(device)
    with torch.no_grad():
        depth_output = depth_model(**depth_inputs)
    depth_map = depth_output.predicted_depth.squeeze().cpu().numpy()
    height, width = rgb_image.shape[0], rgb_image.shape[1]
    depth_map_resized = cv2.resize(depth_map, (width, height))

    shared_visual_context = []
    for box in boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        label = class_names[int(box.cls[0])]
        conf = float(box.conf[0])
        depth_crop = depth_map_resized[y1:y2, x1:x2]
        # Degenerate boxes can produce an empty crop; mark depth unknown.
        avg_depth = float(depth_crop.mean()) if depth_crop.size > 0 else None
        x_center = (x1 + x2) / 2
        if x_center < width / 3:
            pos = "left"
        elif x_center > 2 * width / 3:
            pos = "right"
        else:
            pos = "center"
        shared_visual_context.append({
            "label": label,
            "confidence": conf,
            "avg_depth": avg_depth,
            "position": pos
        })

    # FIX: with no detections the old code returned the malformed sentence
    # "In the image, ." — report the empty scene explicitly instead.
    if not shared_visual_context:
        return "In the image, no objects were detected."

    descriptions = []
    for obj in shared_visual_context:
        # FIX: compare against None, not truthiness — a legitimate average
        # depth of 0.0 was previously reported as "unknown".
        d = f"{obj['avg_depth']:.1f} units" if obj["avg_depth"] is not None else "unknown"
        s = obj.get("position", "unknown")
        c = obj.get("confidence", 0.0)
        descriptions.append(f"a {obj['label']} ({c:.2f} confidence) is at {d} on the {s}")
    return "In the image, " + ", ".join(descriptions) + "."
# -------------------------------
# Main Multimodal Processing Function
# -------------------------------
# Module-level default session: a single shared conversation state.
session = VisualQAState()
def process_inputs(
    session: VisualQAState,
    image: Image.Image = None,
    question: str = "",
    audio_path: str = None,
    enable_tts: bool = True
):
    """Answer a (possibly spoken) question about an image.

    A new *image* resets *session* with a freshly generated visual context;
    *audio_path*, when given, is transcribed with Whisper and appended to
    *question*. Returns ``(answer, output_audio_path)``, where the audio
    path is ``None`` when *enable_tts* is False.
    """
    # FIX: test identity, not truthiness — PIL images are not guaranteed to
    # define __bool__, and an explicit None check states the intent.
    if image is not None:
        visual_context = generate_visual_context(image)
        session.reset(image, visual_context)
    if audio_path:
        audio_text = whisper_pipe(audio_path)["text"]
        # FIX: avoid a stray leading space when no typed question was given.
        question = f"{question} {audio_text}".strip()
    session.add_question(question)
    prompt = f"{session.visual_context}\n\nUser Question: {question}"
    # NOTE(review): gemma_pipe is built as an "image-to-text" pipeline but is
    # handed a plain text prompt and never sees session.current_image —
    # confirm the intended task/call signature against the transformers docs.
    gemma_output = gemma_pipe(prompt, max_new_tokens=200)
    answer = gemma_output[0]["generated_text"]
    session.add_answer(answer)
    # Optionally synthesize the answer to a WAV file in the working directory.
    output_audio_path = "response.wav"
    if enable_tts:
        tts.tts_to_file(text=answer, file_path=output_audio_path)
    else:
        output_audio_path = None
    return answer, output_audio_path