saa231 committed on
Commit
00aadd5
·
verified ·
1 Parent(s): 8ad7a6b

Update project_model.py

Browse files
Files changed (1) hide show
  1. project_model.py +21 -58
project_model.py CHANGED
@@ -1,5 +1,7 @@
1
  # project_module.py
2
- import torch, cv2, os, time
 
 
3
  import numpy as np
4
  from PIL import Image
5
  from ultralytics import YOLO
@@ -10,96 +12,55 @@ from huggingface_hub import login
10
  # Authenticate to Hugging Face using environment token
11
  login(token=os.environ["HUGGING_FACE_HUB_TOKEN"])
12
 
13
- # Set device for computation
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
- # Load models
17
- yolo_model = YOLO("yolov9c.pt")
18
  depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device).eval()
19
  depth_feat = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
20
 
 
21
  whisper_pipe = pipeline(
22
  "automatic-speech-recognition",
23
  model="openai/whisper-small",
24
  device=0 if torch.cuda.is_available() else -1
25
  )
26
 
 
27
  gemma_pipe = pipeline(
28
- "image-text-to-text",
29
  model="google/gemma-3-4b-it",
30
  device=0 if torch.cuda.is_available() else -1,
31
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
32
  )
33
 
 
34
  tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC")
35
 
36
  # -------------------------------
37
- # Smart Session Management
38
  # -------------------------------
39
 
40
  class VisualQAState:
41
- TIMEOUT_SECONDS = 60
42
-
43
  def __init__(self):
44
  self.current_image: Image.Image = None
45
  self.visual_context: str = ""
46
  self.message_history = []
47
- self.last_interaction_time = time.time()
48
-
49
- def _check_timeout(self):
50
- if time.time() - self.last_interaction_time > self.TIMEOUT_SECONDS:
51
- self.soft_reset()
52
-
53
- def soft_reset(self):
54
- self.message_history = [{
55
- "role": "user",
56
- "content": [
57
- {"type": "image", "image": self.current_image},
58
- {"type": "text", "text": self.visual_context}
59
- ]
60
- }]
61
- print("🔄 Session timed out: soft reset applied.")
62
 
63
  def reset(self, image: Image.Image, visual_context: str):
64
  self.current_image = image
65
  self.visual_context = visual_context
66
- self.message_history = [{
67
- "role": "user",
68
- "content": [
69
- {"type": "image", "image": self.current_image},
70
- {"type": "text", "text": self.visual_context}
71
- ]
72
- }]
73
- self.last_interaction_time = time.time()
74
 
75
  def add_question(self, question: str):
76
- self._check_timeout()
77
- if not self.message_history or self.message_history[-1]["role"] == "assistant":
78
- self.message_history.append({
79
- "role": "user",
80
- "content": [{"type": "text", "text": question}]
81
- })
82
- self.last_interaction_time = time.time()
83
 
84
  def add_answer(self, answer: str):
85
- self._check_timeout()
86
- self.message_history.append({
87
- "role": "assistant",
88
- "content": [{"type": "text", "text": answer}]
89
- })
90
- self.last_interaction_time = time.time()
91
-
92
- def export_transcript(self) -> str:
93
- transcript = []
94
- for turn in self.message_history:
95
- role = turn["role"].capitalize()
96
- for entry in turn["content"]:
97
- if entry["type"] == "text":
98
- transcript.append(f"{role}: {entry['text']}")
99
- return "\n\n".join(transcript)
100
 
101
  # -------------------------------
102
- # Image Context Generation
103
  # -------------------------------
104
 
105
  def generate_visual_context(pil_image: Image.Image) -> str:
@@ -121,6 +82,7 @@ def generate_visual_context(pil_image: Image.Image) -> str:
121
  x1, y1, x2, y2 = map(int, box.xyxy[0])
122
  label = class_names[int(box.cls[0])]
123
  conf = float(box.conf[0])
 
124
  depth_crop = depth_map_resized[y1:y2, x1:x2]
125
  avg_depth = float(depth_crop.mean()) if depth_crop.size > 0 else None
126
 
@@ -144,7 +106,7 @@ def generate_visual_context(pil_image: Image.Image) -> str:
144
  return "In the image, " + ", ".join(descriptions) + "."
145
 
146
  # -------------------------------
147
- # Main Processing
148
  # -------------------------------
149
 
150
  session = VisualQAState()
@@ -166,8 +128,10 @@ def process_inputs(
166
 
167
  session.add_question(question)
168
 
169
- gemma_output = gemma_pipe(text=session.message_history, max_new_tokens=200)
170
- answer = gemma_output[0]["generated_text"][-1]["content"]
 
 
171
 
172
  session.add_answer(answer)
173
 
@@ -178,4 +142,3 @@ def process_inputs(
178
  output_audio_path = None
179
 
180
  return answer, output_audio_path
181
-
 
1
  # project_module.py
2
+
3
+ # Import libraries for ML, CV, NLP, audio, and TTS
4
+ import torch, cv2, os
5
  import numpy as np
6
  from PIL import Image
7
  from ultralytics import YOLO
 
12
  # Authenticate to Hugging Face using environment token
13
  login(token=os.environ["HUGGING_FACE_HUB_TOKEN"])
14
 
15
+ # Set device for computation (GPU if available)
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
+ # Load all models
19
+ yolo_model = YOLO("yolov9c.pt") # YOLOv9 for object detection
20
  depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device).eval()
21
  depth_feat = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
22
 
23
+ # Whisper for audio transcription
24
  whisper_pipe = pipeline(
25
  "automatic-speech-recognition",
26
  model="openai/whisper-small",
27
  device=0 if torch.cuda.is_available() else -1
28
  )
29
 
30
+ # GEMMA for image+text to text QA
31
  gemma_pipe = pipeline(
32
+ "image-to-text",
33
  model="google/gemma-3-4b-it",
34
  device=0 if torch.cuda.is_available() else -1,
35
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
36
  )
37
 
38
+ # Text-to-speech
39
  tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC")
40
 
41
  # -------------------------------
42
+ # Session Management Class
43
  # -------------------------------
44
 
45
  class VisualQAState:
 
 
46
  def __init__(self):
47
  self.current_image: Image.Image = None
48
  self.visual_context: str = ""
49
  self.message_history = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  def reset(self, image: Image.Image, visual_context: str):
52
  self.current_image = image
53
  self.visual_context = visual_context
54
+ self.message_history = []
 
 
 
 
 
 
 
55
 
56
  def add_question(self, question: str):
57
+ self.message_history.append({"role": "user", "content": question})
 
 
 
 
 
 
58
 
59
  def add_answer(self, answer: str):
60
+ self.message_history.append({"role": "assistant", "content": answer})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  # -------------------------------
63
+ # Generate Context from Image
64
  # -------------------------------
65
 
66
  def generate_visual_context(pil_image: Image.Image) -> str:
 
82
  x1, y1, x2, y2 = map(int, box.xyxy[0])
83
  label = class_names[int(box.cls[0])]
84
  conf = float(box.conf[0])
85
+
86
  depth_crop = depth_map_resized[y1:y2, x1:x2]
87
  avg_depth = float(depth_crop.mean()) if depth_crop.size > 0 else None
88
 
 
106
  return "In the image, " + ", ".join(descriptions) + "."
107
 
108
  # -------------------------------
109
+ # Main Multimodal Processing Function
110
  # -------------------------------
111
 
112
  session = VisualQAState()
 
128
 
129
  session.add_question(question)
130
 
131
+ prompt = f"{session.visual_context}\n\nUser Question: {question}"
132
+
133
+ gemma_output = gemma_pipe(prompt, max_new_tokens=200)
134
+ answer = gemma_output[0]["generated_text"]
135
 
136
  session.add_answer(answer)
137
 
 
142
  output_audio_path = None
143
 
144
  return answer, output_audio_path