saa231 committed on
Commit
dd077b5
·
verified ·
1 Parent(s): 8cdf492

Update project_model.py

Browse files
Files changed (1) hide show
  1. project_model.py +63 -14
project_model.py CHANGED
@@ -1,3 +1,10 @@
 
 
 
 
 
 
 
1
  # project_module.py
2
 
3
  # Import libraries for ML, CV, NLP, audio, and TTS
@@ -17,8 +24,8 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
  # Load all models
19
  yolo_model = YOLO("yolov9c.pt") # YOLOv9 for object detection
20
- depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device).eval()
21
- depth_feat = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
22
 
23
  # Whisper for audio transcription
24
  whisper_pipe = pipeline(
@@ -29,7 +36,7 @@ whisper_pipe = pipeline(
29
 
30
  # GEMMA for image+text to text QA
31
  gemma_pipe = pipeline(
32
- "image-to-text",
33
  model="google/gemma-3-4b-it",
34
  device=0 if torch.cuda.is_available() else -1,
35
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
@@ -43,56 +50,86 @@ tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC")
43
  # -------------------------------
44
 
45
  class VisualQAState:
 
 
 
46
  def __init__(self):
47
  self.current_image: Image.Image = None
48
  self.visual_context: str = ""
49
  self.message_history = []
50
 
51
  def reset(self, image: Image.Image, visual_context: str):
 
 
 
 
52
  self.current_image = image
53
  self.visual_context = visual_context
54
- self.message_history = []
 
 
 
 
 
 
55
 
56
  def add_question(self, question: str):
 
 
 
 
57
  if not self.message_history or self.message_history[-1]["role"] == "assistant":
58
  self.message_history.append({
59
  "role": "user",
60
- "content": [
61
- {"type": "image", "image": self.current_image},
62
- {"type": "text", "text": question}
63
- ]
64
  })
65
 
66
  def add_answer(self, answer: str):
67
- self.message_history.append({"role": "assistant", "content": answer})
 
 
 
 
 
 
68
 
69
  # -------------------------------
70
  # Generate Context from Image
71
  # -------------------------------
72
 
73
  def generate_visual_context(pil_image: Image.Image) -> str:
 
 
 
 
 
74
  rgb_image = np.array(pil_image)
75
  cv2_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
76
 
 
77
  yolo_results = yolo_model.predict(cv2_image)[0]
78
  boxes = yolo_results.boxes
79
  class_names = yolo_model.names
80
 
 
81
  depth_inputs = depth_feat(images=pil_image, return_tensors="pt").to(device)
82
  with torch.no_grad():
83
  depth_output = depth_model(**depth_inputs)
84
  depth_map = depth_output.predicted_depth.squeeze().cpu().numpy()
85
  depth_map_resized = cv2.resize(depth_map, (rgb_image.shape[1], rgb_image.shape[0]))
86
 
 
87
  shared_visual_context = []
88
  for box in boxes:
89
  x1, y1, x2, y2 = map(int, box.xyxy[0])
90
  label = class_names[int(box.cls[0])]
91
  conf = float(box.conf[0])
92
 
 
93
  depth_crop = depth_map_resized[y1:y2, x1:x2]
94
  avg_depth = float(depth_crop.mean()) if depth_crop.size > 0 else None
95
 
 
96
  x_center = (x1 + x2) / 2
97
  pos = "left" if x_center < rgb_image.shape[1] / 3 else "right" if x_center > 2 * rgb_image.shape[1] / 3 else "center"
98
 
@@ -103,6 +140,7 @@ def generate_visual_context(pil_image: Image.Image) -> str:
103
  "position": pos
104
  })
105
 
 
106
  descriptions = []
107
  for obj in shared_visual_context:
108
  d = f"{obj['avg_depth']:.1f} units" if obj["avg_depth"] else "unknown"
@@ -116,6 +154,7 @@ def generate_visual_context(pil_image: Image.Image) -> str:
116
  # Main Multimodal Processing Function
117
  # -------------------------------
118
 
 
119
  session = VisualQAState()
120
 
121
  def process_inputs(
@@ -125,27 +164,37 @@ def process_inputs(
125
  audio_path: str = None,
126
  enable_tts: bool = True
127
  ):
 
 
 
 
 
 
 
128
  if image:
129
  visual_context = generate_visual_context(image)
130
  session.reset(image, visual_context)
131
 
 
132
  if audio_path:
133
  audio_text = whisper_pipe(audio_path)["text"]
134
  question += " " + audio_text
135
 
 
136
  session.add_question(question)
137
 
138
- prompt = f"{session.visual_context}\n\nUser Question: {question}"
139
-
140
- gemma_output = gemma_pipe(prompt, max_new_tokens=200)
141
- answer = gemma_output[0]["generated_text"]
142
 
 
143
  session.add_answer(answer)
144
 
 
145
  output_audio_path = "response.wav"
146
  if enable_tts:
147
  tts.tts_to_file(text=answer, file_path=output_audio_path)
148
  else:
149
  output_audio_path = None
150
 
151
- return answer, output_audio_path
 
1
+ # -*- coding: utf-8 -*-
2
+ """project_model.ipynb
3
+ Automatically generated by Colab.
4
+ Original file is located at
5
+ https://colab.research.google.com/drive/1oopkA5yIlfizFuhXOPmTK7MUNh3Qasa3
6
+ """
7
+
8
  # project_module.py
9
 
10
  # Import libraries for ML, CV, NLP, audio, and TTS
 
24
 
25
  # Load all models
26
  yolo_model = YOLO("yolov9c.pt") # YOLOv9 for object detection
27
+ depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device).eval() # MiDaS for depth
28
+ depth_feat = DPTFeatureExtractor.from_pretrained("Intel/dpt-large") # Feature extractor for depth model
29
 
30
  # Whisper for audio transcription
31
  whisper_pipe = pipeline(
 
36
 
37
  # GEMMA for image+text to text QA
38
  gemma_pipe = pipeline(
39
+ "image-text-to-text",
40
  model="google/gemma-3-4b-it",
41
  device=0 if torch.cuda.is_available() else -1,
42
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
 
50
  # -------------------------------
51
 
52
class VisualQAState:
    """Holds the active image, its derived visual context, and the chat transcript."""

    def __init__(self):
        # No image loaded yet; conversation history starts empty.
        self.current_image: Image.Image = None
        self.visual_context: str = ""
        self.message_history = []

    def reset(self, image: Image.Image, visual_context: str):
        """Begin a fresh conversation around a newly uploaded image.

        Seeds the history with a single user turn carrying both the image
        and its textual context, so the model sees them together.
        """
        self.current_image = image
        self.visual_context = visual_context
        self.message_history = [{
            "role": "user",
            "content": [
                {"type": "image", "image": self.current_image},
                {"type": "text", "text": self.visual_context},
            ],
        }]

    def add_question(self, question: str):
        """Record a user question, but only when it keeps turns alternating.

        A question is accepted when the history is empty or the last turn
        came from the assistant; otherwise it is silently dropped.
        """
        assistant_spoke_last = (
            not self.message_history
            or self.message_history[-1]["role"] == "assistant"
        )
        if assistant_spoke_last:
            self.message_history.append({
                "role": "user",
                "content": [{"type": "text", "text": question}],
            })

    def add_answer(self, answer: str):
        """Record the assistant's reply as a text-only turn."""
        self.message_history.append({
            "role": "assistant",
            "content": [{"type": "text", "text": answer}],
        })
 
96
  # -------------------------------
97
  # Generate Context from Image
98
  # -------------------------------
99
 
100
  def generate_visual_context(pil_image: Image.Image) -> str:
101
+ """
102
+ Processes the image to extract object labels, depth info, and locations.
103
+ Builds a natural language context description for use in prompting.
104
+ """
105
+ # Convert to OpenCV and RGB formats
106
  rgb_image = np.array(pil_image)
107
  cv2_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
108
 
109
+ # Object detection using YOLO
110
  yolo_results = yolo_model.predict(cv2_image)[0]
111
  boxes = yolo_results.boxes
112
  class_names = yolo_model.names
113
 
114
+ # Depth estimation using MiDaS
115
  depth_inputs = depth_feat(images=pil_image, return_tensors="pt").to(device)
116
  with torch.no_grad():
117
  depth_output = depth_model(**depth_inputs)
118
  depth_map = depth_output.predicted_depth.squeeze().cpu().numpy()
119
  depth_map_resized = cv2.resize(depth_map, (rgb_image.shape[1], rgb_image.shape[0]))
120
 
121
+ # Extract contextual information for each object
122
  shared_visual_context = []
123
  for box in boxes:
124
  x1, y1, x2, y2 = map(int, box.xyxy[0])
125
  label = class_names[int(box.cls[0])]
126
  conf = float(box.conf[0])
127
 
128
+ # Compute average depth of object
129
  depth_crop = depth_map_resized[y1:y2, x1:x2]
130
  avg_depth = float(depth_crop.mean()) if depth_crop.size > 0 else None
131
 
132
+ # Determine object horizontal position
133
  x_center = (x1 + x2) / 2
134
  pos = "left" if x_center < rgb_image.shape[1] / 3 else "right" if x_center > 2 * rgb_image.shape[1] / 3 else "center"
135
 
 
140
  "position": pos
141
  })
142
 
143
+ # Convert context to a readable sentence
144
  descriptions = []
145
  for obj in shared_visual_context:
146
  d = f"{obj['avg_depth']:.1f} units" if obj["avg_depth"] else "unknown"
 
154
  # Main Multimodal Processing Function
155
  # -------------------------------
156
 
157
+ # Create a global session object to persist across follow-ups
158
  session = VisualQAState()
159
 
160
  def process_inputs(
 
164
  audio_path: str = None,
165
  enable_tts: bool = True
166
  ):
167
+ """
168
+ Handles a new image upload or a follow-up question.
169
+ Combines image context, audio transcription, and text input to generate a GEMMA-based answer.
170
+ Optionally outputs audio using TTS.
171
+ """
172
+
173
+ # If new image is provided, reset session and build new context
174
  if image:
175
  visual_context = generate_visual_context(image)
176
  session.reset(image, visual_context)
177
 
178
+ # If user gave an audio clip, transcribe it and append to question
179
  if audio_path:
180
  audio_text = whisper_pipe(audio_path)["text"]
181
  question += " " + audio_text
182
 
183
+ # Append question to conversation history (only if alternating correctly)
184
  session.add_question(question)
185
 
186
+ # Generate response using GEMMA with full conversation history
187
+ gemma_output = gemma_pipe(text=session.message_history, max_new_tokens=200)
188
+ answer = gemma_output[0]["generated_text"][-1]["content"]
 
189
 
190
+ # Append GEMMA's response to the history to maintain alternating structure
191
  session.add_answer(answer)
192
 
193
+ # If TTS is enabled, synthesize answer as speech
194
  output_audio_path = "response.wav"
195
  if enable_tts:
196
  tts.tts_to_file(text=answer, file_path=output_audio_path)
197
  else:
198
  output_audio_path = None
199
 
200
+ return answer, output_audio_path