# project_module.py
# Import libraries for ML, CV, NLP, audio, and TTS
import torch, cv2, os
import numpy as np
from PIL import Image
from ultralytics import YOLO
from transformers import pipeline, DPTFeatureExtractor, DPTForDepthEstimation
from TTS.api import TTS
from huggingface_hub import login
# Authenticate to Hugging Face using the HUGGING_FACE_HUB_TOKEN env var.
# NOTE: os.environ[...] raises KeyError if the variable is unset — the module
# fails fast at import time rather than later at model download.
login(token=os.environ["HUGGING_FACE_HUB_TOKEN"])
# Run models on the GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# ---- Model loading (all of this runs at import time and may download weights) ----
yolo_model = YOLO("yolov9c.pt") # YOLOv9 for object detection
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device).eval() # DPT depth estimator, eval mode (no dropout/grad)
depth_feat = DPTFeatureExtractor.from_pretrained("Intel/dpt-large") # Preprocessor matching the depth model
# Whisper for audio transcription (pipeline device convention: 0 = first GPU, -1 = CPU)
whisper_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=0 if torch.cuda.is_available() else -1
)
# Gemma multimodal pipeline for image+text question answering.
# bfloat16 is used on GPU; float32 is kept on CPU for compatibility.
gemma_pipe = pipeline(
    "image-text-to-text",
    model="google/gemma-3-4b-it",
    device=0 if torch.cuda.is_available() else -1,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
)
# Text-to-speech model used to voice the assistant's answers.
tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC")
# -------------------------------
# Session Management Class
# -------------------------------
class VisualQAState:
    """
    Per-session store for the current image context and chat history, so
    follow-up questions can refer back to the most recently uploaded image.

    Attributes:
        current_image: The most recently uploaded image (None until reset()).
        annotated_image: Copy of the image with detection boxes drawn on it.
        visual_context: Natural-language summary of the detected objects.
        message_history: Chat turns as {"role", "content"} dicts, seeded by
            reset() with a system prompt plus one user turn.
    """

    def __init__(self):
        # Fixed: these were annotated plain `Image.Image` while being
        # initialized to None — the Optional annotation reflects reality.
        # (Forward-reference strings keep the annotations lazy.)
        self.current_image: "Image.Image | None" = None
        self.annotated_image: "Image.Image | None" = None
        self.visual_context: str = ""
        self.message_history: list = []

    def reset(self, image: "Image.Image", annotated_image: "Image.Image", visual_context: str):
        """
        Start a new conversation for a freshly uploaded image.

        Replaces the stored image/context and re-seeds message_history with
        the system prompt and a first user turn carrying the image and its
        auto-generated visual description.
        """
        self.current_image = image
        self.annotated_image = annotated_image
        self.visual_context = visual_context
        self.message_history = [
            {
                "role": "system",  # behavior contract for the assistant
                "content": (
                    "You are a helpful visual assistant designed for visually impaired users that assists users by answering their questions. "
                    'If unsure, say "I am not certain."'
                )
            },
            {
                "role": "user",  # initial user turn: the image + its description
                "content": [
                    {"type": "image", "image": self.current_image},
                    {"type": "text", "text": self.visual_context}
                ]
            }
        ]

    def add_question(self, question: str):
        """
        Append a user question, keeping user/assistant turns strictly
        alternating.

        NOTE(review): reset() leaves a *user* turn as the last message, so
        the first question asked right after an image upload is silently
        dropped here — confirm this is intended.
        """
        if not self.message_history or self.message_history[-1]["role"] == "assistant":
            self.message_history.append({
                "role": "user",
                "content": [{"type": "text", "text": question}]
            })

    def add_answer(self, answer: str):
        """Append the assistant's reply to the conversation history."""
        self.message_history.append({
            "role": "assistant",
            "content": [{"type": "text", "text": answer}]
        })
# -------------------------------
# Generate Context from Image
# -------------------------------
def generate_visual_context(pil_image: "Image.Image"):
    """
    Detect objects and estimate depth in an image, then summarize the scene.

    Args:
        pil_image: Input image (PIL, assumed RGB — TODO confirm callers never
            pass RGBA/grayscale).

    Returns:
        (context_sentence, annotated_pil): a natural-language description of
        each detected object (label, confidence, average depth, left/center/
        right position) and a copy of the image with labeled boxes drawn.
    """
    # Convert to OpenCV BGR; keep the RGB array for image dimensions.
    rgb_image = np.array(pil_image)
    cv2_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
    height, width = rgb_image.shape[:2]
    # Object detection — predict() returns one result per image; take the first.
    yolo_results = yolo_model.predict(cv2_image)[0]
    boxes = yolo_results.boxes
    class_names = yolo_model.names
    # Depth estimation; resize the predicted map to the image size so that
    # detection-box coordinates index into it directly.
    depth_inputs = depth_feat(images=pil_image, return_tensors="pt").to(device)
    with torch.no_grad():
        depth_output = depth_model(**depth_inputs)
    depth_map = depth_output.predicted_depth.squeeze().cpu().numpy()
    depth_map_resized = cv2.resize(depth_map, (width, height))
    # Single pass over detections: draw boxes AND collect per-object context
    # (previously two identical loops over `boxes`).
    annotated_image = cv2_image.copy()
    shared_visual_context = []
    for box in boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        label = class_names[int(box.cls[0])]
        conf = float(box.conf[0])
        cv2.rectangle(annotated_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(annotated_image, f"{label} {conf:.2f}", (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
        # Clamp to the image bounds: a negative x1/y1 would silently produce
        # a wrong crop via Python's negative indexing.
        cx1, cy1 = max(0, x1), max(0, y1)
        cx2, cy2 = min(width, x2), min(height, y2)
        depth_crop = depth_map_resized[cy1:cy2, cx1:cx2]
        avg_depth = float(depth_crop.mean()) if depth_crop.size > 0 else None
        # Horizontal thirds decide left/center/right.
        x_center = (x1 + x2) / 2
        pos = "left" if x_center < width / 3 else "right" if x_center > 2 * width / 3 else "center"
        shared_visual_context.append({
            "label": label,
            "confidence": conf,
            "avg_depth": avg_depth,
            "position": pos
        })
    descriptions = []
    for obj in shared_visual_context:
        # BUG FIX: compare against None — an average depth of exactly 0.0 is
        # a valid measurement, but the old truthiness test reported "unknown".
        d = f"{obj['avg_depth']:.1f} units" if obj["avg_depth"] is not None else "unknown"
        s = obj.get("position", "unknown")
        c = obj.get("confidence", 0.0)
        descriptions.append(f"a {obj['label']} ({c:.2f} confidence) is at {d} on the {s}")
    # Robustness: avoid the malformed sentence "In the image, ." when YOLO
    # finds nothing.
    if descriptions:
        context_sentence = "In the image, " + ", ".join(descriptions) + "."
    else:
        context_sentence = "In the image, no objects were detected."
    # Convert the annotated BGR image back to a PIL RGB image.
    annotated_pil = Image.fromarray(cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB))
    return context_sentence, annotated_pil
# -------------------------------
# Main Multimodal Processing Function
# -------------------------------
# Module-level session shared across calls so follow-up questions can reuse
# the context of the last uploaded image (see VisualQAState).
session = VisualQAState()
def process_inputs(
    session: "VisualQAState",
    image: "Image.Image" = None,
    question: str = "",
    audio_path: str = None,
    enable_tts: bool = True
):
    """
    Answer a (typed and/or spoken) question about a new or previously
    uploaded image.

    Args:
        session: Conversation state; reset when a new image is supplied.
        image: Optional new image to analyze. When None, the question is a
            follow-up about the image already stored in ``session``.
        question: Typed question text (may be empty if audio is provided).
        audio_path: Optional path to an audio file with a spoken question.
        enable_tts: When True, synthesize the answer to "response.wav".

    Returns:
        (answer, output_audio_path): the model's answer string and the path
        of the synthesized speech file, or None when TTS is disabled.
    """
    if image is not None:
        # New image: build its visual context and start a fresh session.
        visual_context, annotated_image = generate_visual_context(image)
        session.reset(image, annotated_image, visual_context)
    else:
        # BUG FIX: on follow-up calls with no new image, `visual_context`
        # was an unbound local and raised NameError — reuse the stored one.
        visual_context = session.visual_context
    if audio_path:
        # Transcribe the spoken question and append it to any typed text;
        # strip() avoids a leading space when `question` was empty.
        audio_text = whisper_pipe(audio_path)["text"]
        question = (question + " " + audio_text.strip()).strip()
    # Record the user's question in the session history.
    session.add_question(question)
    # NOTE(review): the prompt is rebuilt from scratch each turn and
    # session.message_history is never sent to the model — confirm whether
    # true multi-turn context is intended.
    vqa_prompt = (
        "You are a helpful visual assistant designed for visually impaired users "
        "that assists users by answering their questions. Answer the following "
        "question with the help of the shared visual context: "
        + question
        # BUG FIX: a separator space was missing, fusing the question into
        # the literal "Shared visual context:" label.
        + " Shared visual context: "
        + visual_context
    )
    # Gemma chat-format input: one user turn with the image and the prompt.
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": session.current_image},
            {"type": "text", "text": vqa_prompt}]
    }]
    gemma_output = gemma_pipe(text=messages, max_new_tokens=500)
    # BUG FIX: default the answer first so it is always bound — previously a
    # non-list pipeline output left `answer` undefined (NameError).
    answer = "No valid output from Gemma model."
    if isinstance(gemma_output, list) and gemma_output:
        gemma_text = gemma_output[0]["generated_text"][-1]["content"]
        if isinstance(gemma_text, str):
            answer = gemma_text
    # Save the assistant's answer into the session history.
    session.add_answer(answer)
    # Optionally voice the answer with TTS.
    output_audio_path = "response.wav"
    if enable_tts:
        tts.tts_to_file(text=answer, file_path=output_audio_path)
    else:
        output_audio_path = None
    return answer, output_audio_path