MutimodalVisionAssistant / project_model.py
saa231's picture
Update project_model.py
9c1199b verified
raw
history blame
3.87 kB
# -*- coding: utf-8 -*-
"""project_model.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1oopkA5yIlfizFuhXOPmTK7MUNh3Qasa3
"""
# project_module.py
import torch, cv2, time, os
import numpy as np
from PIL import Image
from ultralytics import YOLO
from transformers import pipeline, DPTFeatureExtractor, DPTForDepthEstimation
from TTS.api import TTS
from huggingface_hub import login
import os
# Authenticate with the Hugging Face Hub; the gated models loaded below
# (e.g. google/gemma-3-4b-it) require a valid access token.  Fail fast with
# an actionable message instead of an opaque KeyError when the variable is
# unset or empty.
_hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
if not _hf_token:
    raise RuntimeError(
        "HUGGING_FACE_HUB_TOKEN is not set; export your Hugging Face access "
        "token before running this module."
    )
login(token=_hf_token)
# Load models.
# All models share one device choice: GPU when CUDA is available, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu" # Enable GPU
yolo_model = YOLO("yolov9c.pt") # YOLOv9 object detector (weights downloaded on first use)
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device).eval() # DPT monocular depth ("MiDaS"-style); eval() disables dropout/batch-norm updates
depth_feat = DPTFeatureExtractor.from_pretrained("Intel/dpt-large") # Preprocessor paired with the DPT checkpoint
# Whisper ASR pipeline; transformers pipelines take a CUDA index (0) or -1 for CPU.
whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-small", device=0 if torch.cuda.is_available() else -1) # Load Whisper
# Gemma-3-4B instruction-tuned multimodal model for visual question answering.
# bfloat16 halves GPU memory; full float32 is kept on CPU where bf16 is slow/unsupported.
gemma_pipe = pipeline(
"image-text-to-text",
model="google/gemma-3-4b-it",
device=0 if torch.cuda.is_available() else -1,
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
)
tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC") # Text-to-speech (Tacotron2, English single-speaker)
# Function to process image and audio
def process_inputs(image: Image.Image, audio_path: str):
# Convert PIL image to OpenCV format
rgb_image = np.array(image)
cv2_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
pil_image = image
# YOLO Detection
yolo_results = yolo_model.predict(cv2_image)[0]
boxes = yolo_results.boxes
class_names = yolo_model.names
# MiDaS Depth
depth_inputs = depth_feat(images=pil_image, return_tensors="pt").to(device)
with torch.no_grad():
depth_output = depth_model(**depth_inputs)
depth_map = depth_output.predicted_depth.squeeze().cpu().numpy()
depth_map_resized = cv2.resize(depth_map, (rgb_image.shape[1], rgb_image.shape[0]))
# Visual Context
shared_visual_context = []
for box in boxes:
x1, y1, x2, y2 = map(int, box.xyxy[0])
label = class_names[int(box.cls[0])]
conf = float(box.conf[0])
depth_crop = depth_map_resized[y1:y2, x1:x2]
avg_depth = float(depth_crop.mean()) if depth_crop.size > 0 else None
x_center = (x1 + x2) / 2
pos = "left" if x_center < rgb_image.shape[1] / 3 else "right" if x_center > 2 * rgb_image.shape[1] / 3 else "center"
shared_visual_context.append({
"label": label,
"confidence": conf,
"avg_depth": avg_depth,
"position": pos
})
# Build Context Text
def build_context_description(context):
descriptions = []
for obj in context:
d = f"{obj['avg_depth']:.1f} units" if obj["avg_depth"] else "unknown"
s = obj.get("position", "unknown")
c = obj.get("confidence", 0.0)
descriptions.append(f"a {obj['label']} ({c:.2f} confidence) is at {d} on the {s}")
return "In the image, " + ", ".join(descriptions) + "."
context_text = build_context_description(shared_visual_context)
# Transcribe audio
transcription = whisper_pipe(audio_path)["text"]
vqa_prompt = context_text + " " + transcription
# GEMMA answer
messages = [{
"role": "user",
"content": [
{"type": "image", "image": pil_image},
{"type": "text", "text": vqa_prompt}
]
}]
gemma_output = gemma_pipe(text=messages, max_new_tokens=200)
answer = gemma_output[0]["generated_text"][-1]["content"]
# Generate speech
output_audio_path = "response.wav"
tts.tts_to_file(text=answer, file_path=output_audio_path)
return answer, output_audio_path