saa231 committed on
Commit
8ad7a6b
·
verified ·
1 Parent(s): aac3d9d

Update project_model.py

Browse files
Files changed (1) hide show
  1. project_model.py +41 -62
project_model.py CHANGED
@@ -1,16 +1,5 @@
1
- # -*- coding: utf-8 -*-
2
- """project_model.ipynb
3
-
4
- Automatically generated by Colab.
5
-
6
- Original file is located at
7
- https://colab.research.google.com/drive/1oopkA5yIlfizFuhXOPmTK7MUNh3Qasa3
8
- """
9
-
10
  # project_module.py
11
-
12
- # Import libraries for ML, CV, NLP, audio, and TTS
13
- import torch, cv2, os
14
  import numpy as np
15
  from PIL import Image
16
  from ultralytics import YOLO
@@ -21,22 +10,20 @@ from huggingface_hub import login
21
  # Authenticate to Hugging Face using environment token
22
  login(token=os.environ["HUGGING_FACE_HUB_TOKEN"])
23
 
24
- # Set device for computation (GPU if available)
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
 
27
- # Load all models
28
- yolo_model = YOLO("yolov9c.pt") # YOLOv9 for object detection
29
- depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device).eval() # MiDaS for depth
30
- depth_feat = DPTFeatureExtractor.from_pretrained("Intel/dpt-large") # Feature extractor for depth model
31
 
32
- # Whisper for audio transcription
33
  whisper_pipe = pipeline(
34
  "automatic-speech-recognition",
35
  model="openai/whisper-small",
36
  device=0 if torch.cuda.is_available() else -1
37
  )
38
 
39
- # GEMMA for image+text to text QA
40
  gemma_pipe = pipeline(
41
  "image-text-to-text",
42
  model="google/gemma-3-4b-it",
@@ -44,27 +31,36 @@ gemma_pipe = pipeline(
44
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
45
  )
46
 
47
- # Text-to-speech
48
  tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC")
49
 
50
  # -------------------------------
51
- # Session Management Class
52
  # -------------------------------
53
 
54
  class VisualQAState:
55
- """
56
- Stores the current image context and chat history for follow-up questions.
57
- """
58
  def __init__(self):
59
  self.current_image: Image.Image = None
60
  self.visual_context: str = ""
61
  self.message_history = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  def reset(self, image: Image.Image, visual_context: str):
64
- """
65
- Called when a new image is uploaded.
66
- Resets context and starts new message history.
67
- """
68
  self.current_image = image
69
  self.visual_context = visual_context
70
  self.message_history = [{
@@ -74,64 +70,60 @@ class VisualQAState:
74
  {"type": "text", "text": self.visual_context}
75
  ]
76
  }]
 
77
 
78
  def add_question(self, question: str):
79
- """
80
- Adds a follow-up question only if the last message was from assistant.
81
- Ensures alternating user/assistant messages.
82
- """
83
  if not self.message_history or self.message_history[-1]["role"] == "assistant":
84
  self.message_history.append({
85
  "role": "user",
86
  "content": [{"type": "text", "text": question}]
87
  })
 
88
 
89
  def add_answer(self, answer: str):
90
- """
91
- Appends the assistant's response to the conversation history.
92
- """
93
  self.message_history.append({
94
  "role": "assistant",
95
  "content": [{"type": "text", "text": answer}]
96
  })
 
 
 
 
 
 
 
 
 
 
97
 
98
  # -------------------------------
99
- # Generate Context from Image
100
  # -------------------------------
101
 
102
  def generate_visual_context(pil_image: Image.Image) -> str:
103
- """
104
- Processes the image to extract object labels, depth info, and locations.
105
- Builds a natural language context description for use in prompting.
106
- """
107
- # Convert to OpenCV and RGB formats
108
  rgb_image = np.array(pil_image)
109
  cv2_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
110
 
111
- # Object detection using YOLO
112
  yolo_results = yolo_model.predict(cv2_image)[0]
113
  boxes = yolo_results.boxes
114
  class_names = yolo_model.names
115
 
116
- # Depth estimation using MiDaS
117
  depth_inputs = depth_feat(images=pil_image, return_tensors="pt").to(device)
118
  with torch.no_grad():
119
  depth_output = depth_model(**depth_inputs)
120
  depth_map = depth_output.predicted_depth.squeeze().cpu().numpy()
121
  depth_map_resized = cv2.resize(depth_map, (rgb_image.shape[1], rgb_image.shape[0]))
122
 
123
- # Extract contextual information for each object
124
  shared_visual_context = []
125
  for box in boxes:
126
  x1, y1, x2, y2 = map(int, box.xyxy[0])
127
  label = class_names[int(box.cls[0])]
128
  conf = float(box.conf[0])
129
-
130
- # Compute average depth of object
131
  depth_crop = depth_map_resized[y1:y2, x1:x2]
132
  avg_depth = float(depth_crop.mean()) if depth_crop.size > 0 else None
133
 
134
- # Determine object horizontal position
135
  x_center = (x1 + x2) / 2
136
  pos = "left" if x_center < rgb_image.shape[1] / 3 else "right" if x_center > 2 * rgb_image.shape[1] / 3 else "center"
137
 
@@ -142,7 +134,6 @@ def generate_visual_context(pil_image: Image.Image) -> str:
142
  "position": pos
143
  })
144
 
145
- # Convert context to a readable sentence
146
  descriptions = []
147
  for obj in shared_visual_context:
148
  d = f"{obj['avg_depth']:.1f} units" if obj["avg_depth"] else "unknown"
@@ -153,10 +144,9 @@ def generate_visual_context(pil_image: Image.Image) -> str:
153
  return "In the image, " + ", ".join(descriptions) + "."
154
 
155
  # -------------------------------
156
- # Main Multimodal Processing Function
157
  # -------------------------------
158
 
159
- # Create a global session object to persist across follow-ups
160
  session = VisualQAState()
161
 
162
  def process_inputs(
@@ -166,33 +156,21 @@ def process_inputs(
166
  audio_path: str = None,
167
  enable_tts: bool = True
168
  ):
169
- """
170
- Handles a new image upload or a follow-up question.
171
- Combines image context, audio transcription, and text input to generate a GEMMA-based answer.
172
- Optionally outputs audio using TTS.
173
- """
174
-
175
- # If new image is provided, reset session and build new context
176
  if image:
177
  visual_context = generate_visual_context(image)
178
  session.reset(image, visual_context)
179
 
180
- # If user gave an audio clip, transcribe it and append to question
181
  if audio_path:
182
  audio_text = whisper_pipe(audio_path)["text"]
183
  question += " " + audio_text
184
 
185
- # Append question to conversation history (only if alternating correctly)
186
  session.add_question(question)
187
 
188
- # Generate response using GEMMA with full conversation history
189
  gemma_output = gemma_pipe(text=session.message_history, max_new_tokens=200)
190
  answer = gemma_output[0]["generated_text"][-1]["content"]
191
 
192
- # Append GEMMA's response to the history to maintain alternating structure
193
  session.add_answer(answer)
194
 
195
- # If TTS is enabled, synthesize answer as speech
196
  output_audio_path = "response.wav"
197
  if enable_tts:
198
  tts.tts_to_file(text=answer, file_path=output_audio_path)
@@ -200,3 +178,4 @@ def process_inputs(
200
  output_audio_path = None
201
 
202
  return answer, output_audio_path
 
 
 
 
 
 
 
 
 
 
 
1
  # project_module.py
2
+ import torch, cv2, os, time
 
 
3
  import numpy as np
4
  from PIL import Image
5
  from ultralytics import YOLO
 
10
  # Authenticate to Hugging Face using environment token
11
  login(token=os.environ["HUGGING_FACE_HUB_TOKEN"])
12
 
13
+ # Set device for computation
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
+ # Load models
17
+ yolo_model = YOLO("yolov9c.pt")
18
+ depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device).eval()
19
+ depth_feat = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
20
 
 
21
  whisper_pipe = pipeline(
22
  "automatic-speech-recognition",
23
  model="openai/whisper-small",
24
  device=0 if torch.cuda.is_available() else -1
25
  )
26
 
 
27
  gemma_pipe = pipeline(
28
  "image-text-to-text",
29
  model="google/gemma-3-4b-it",
 
31
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
32
  )
33
 
 
34
  tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC")
35
 
36
  # -------------------------------
37
+ # Smart Session Management
38
  # -------------------------------
39
 
40
class VisualQAState:
    """Holds the active image context and the alternating user/assistant
    chat history for follow-up visual QA.

    An inactivity timeout trims a stale conversation back to just the
    image-context seed message (the image and its description are kept),
    so old follow-ups do not leak into a resumed session.
    """

    # Idle seconds after which the chat turns are discarded via soft_reset().
    TIMEOUT_SECONDS = 60

    def __init__(self):
        # String annotation avoids evaluating PIL's Image at runtime here;
        # behavior is unchanged when PIL is imported at the top of the file.
        self.current_image: "Image.Image" = None  # most recently uploaded image
        self.visual_context: str = ""             # natural-language scene description
        self.message_history = []                 # alternating user/assistant turns
        self.last_interaction_time = time.time()  # for the inactivity timeout

    def _initial_history(self):
        """Return a fresh history seeded with the current image + its context.

        Shared by reset() and soft_reset() so the seed-message shape is
        defined in exactly one place.
        """
        return [{
            "role": "user",
            "content": [
                {"type": "image", "image": self.current_image},
                {"type": "text", "text": self.visual_context}
            ]
        }]

    def _check_timeout(self):
        # Trim the conversation if the session has been idle too long.
        if time.time() - self.last_interaction_time > self.TIMEOUT_SECONDS:
            self.soft_reset()

    def soft_reset(self):
        """Drop the chat turns but keep the current image and its context."""
        self.message_history = self._initial_history()
        print("🔄 Session timed out: soft reset applied.")

    def reset(self, image: "Image.Image", visual_context: str):
        """Start a brand-new session for a freshly uploaded image."""
        self.current_image = image
        self.visual_context = visual_context
        self.message_history = self._initial_history()
        self.last_interaction_time = time.time()

    def add_question(self, question: str):
        """Append a user question, enforcing user/assistant alternation.

        NOTE(review): immediately after reset() the last message is the
        context seed (role "user"), so the first follow-up question is
        silently dropped by the alternation guard — confirm this is the
        intended behavior.
        """
        self._check_timeout()
        if not self.message_history or self.message_history[-1]["role"] == "assistant":
            self.message_history.append({
                "role": "user",
                "content": [{"type": "text", "text": question}]
            })
        self.last_interaction_time = time.time()

    def add_answer(self, answer: str):
        """Append the assistant's reply to the conversation history."""
        self._check_timeout()
        self.message_history.append({
            "role": "assistant",
            "content": [{"type": "text", "text": answer}]
        })
        self.last_interaction_time = time.time()

    def export_transcript(self) -> str:
        """Render the text parts of the conversation as a readable transcript.

        Image entries are skipped; each text entry becomes "Role: text",
        joined by blank lines.
        """
        transcript = []
        for turn in self.message_history:
            role = turn["role"].capitalize()
            for entry in turn["content"]:
                if entry["type"] == "text":
                    transcript.append(f"{role}: {entry['text']}")
        return "\n\n".join(transcript)
100
 
101
  # -------------------------------
102
+ # Image Context Generation
103
  # -------------------------------
104
 
105
  def generate_visual_context(pil_image: Image.Image) -> str:
 
 
 
 
 
106
  rgb_image = np.array(pil_image)
107
  cv2_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
108
 
 
109
  yolo_results = yolo_model.predict(cv2_image)[0]
110
  boxes = yolo_results.boxes
111
  class_names = yolo_model.names
112
 
 
113
  depth_inputs = depth_feat(images=pil_image, return_tensors="pt").to(device)
114
  with torch.no_grad():
115
  depth_output = depth_model(**depth_inputs)
116
  depth_map = depth_output.predicted_depth.squeeze().cpu().numpy()
117
  depth_map_resized = cv2.resize(depth_map, (rgb_image.shape[1], rgb_image.shape[0]))
118
 
 
119
  shared_visual_context = []
120
  for box in boxes:
121
  x1, y1, x2, y2 = map(int, box.xyxy[0])
122
  label = class_names[int(box.cls[0])]
123
  conf = float(box.conf[0])
 
 
124
  depth_crop = depth_map_resized[y1:y2, x1:x2]
125
  avg_depth = float(depth_crop.mean()) if depth_crop.size > 0 else None
126
 
 
127
  x_center = (x1 + x2) / 2
128
  pos = "left" if x_center < rgb_image.shape[1] / 3 else "right" if x_center > 2 * rgb_image.shape[1] / 3 else "center"
129
 
 
134
  "position": pos
135
  })
136
 
 
137
  descriptions = []
138
  for obj in shared_visual_context:
139
  d = f"{obj['avg_depth']:.1f} units" if obj["avg_depth"] else "unknown"
 
144
  return "In the image, " + ", ".join(descriptions) + "."
145
 
146
  # -------------------------------
147
+ # Main Processing
148
  # -------------------------------
149
 
 
150
  session = VisualQAState()
151
 
152
  def process_inputs(
 
156
  audio_path: str = None,
157
  enable_tts: bool = True
158
  ):
 
 
 
 
 
 
 
159
  if image:
160
  visual_context = generate_visual_context(image)
161
  session.reset(image, visual_context)
162
 
 
163
  if audio_path:
164
  audio_text = whisper_pipe(audio_path)["text"]
165
  question += " " + audio_text
166
 
 
167
  session.add_question(question)
168
 
 
169
  gemma_output = gemma_pipe(text=session.message_history, max_new_tokens=200)
170
  answer = gemma_output[0]["generated_text"][-1]["content"]
171
 
 
172
  session.add_answer(answer)
173
 
 
174
  output_audio_path = "response.wav"
175
  if enable_tts:
176
  tts.tts_to_file(text=answer, file_path=output_audio_path)
 
178
  output_audio_path = None
179
 
180
  return answer, output_audio_path
181
+