dxfoso committed
Commit fa371ed · 1 Parent(s): d33203e

update display for ontology

Files changed (2)
  1. local_models.py +77 -0
  2. ui_components.py +8 -4
local_models.py CHANGED
@@ -49,6 +49,11 @@ class CNNImageCaptioner:
             return f"Model loading failed: {load_result}"
 
         try:
+            # Handle counting prompts specially
+            if prompt and any(word in prompt.lower() for word in ['count', 'how many', 'number of']):
+                # For counting prompts, use better strategy
+                return self._handle_counting_prompt(image, prompt)
+
             # Prepare inputs
             if prompt:
                 inputs = self.processor(image, prompt, return_tensors="pt").to(self.device)
@@ -70,6 +75,78 @@ class CNNImageCaptioner:
 
         except Exception as e:
             return f"Error generating caption: {str(e)}"
+
+    def _handle_counting_prompt(self, image: Image.Image, original_prompt: str) -> str:
+        """Handle counting prompts with better strategy"""
+        try:
+            # Generate multiple descriptions
+            descriptions = []
+
+            # Basic scene description (no prompt - works better)
+            inputs_basic = self.processor(image, return_tensors="pt").to(self.device)
+            with torch.no_grad():
+                out_basic = self.model.generate(**inputs_basic, max_length=50, num_beams=4)
+            basic_desc = self.processor.decode(out_basic[0], skip_special_tokens=True)
+            descriptions.append(basic_desc)
+
+            # People-focused description
+            inputs_people = self.processor(image, "describe people in this image", return_tensors="pt").to(self.device)
+            with torch.no_grad():
+                out_people = self.model.generate(**inputs_people, max_length=50, num_beams=4)
+            people_desc = self.processor.decode(out_people[0], skip_special_tokens=True)
+            if people_desc.startswith("describe people in this image"):
+                people_desc = people_desc[len("describe people in this image"):].strip()
+            descriptions.append(people_desc)
+
+            # Analyze for counting
+            combined_text = " ".join(descriptions).lower()
+            count_result = self._extract_count_from_text(combined_text, original_prompt)
+
+            return count_result
+
+        except Exception as e:
+            return f"Counting analysis failed: {str(e)}"
+
+    def _extract_count_from_text(self, text: str, original_prompt: str) -> str:
+        """Extract count information from text descriptions"""
+        import re
+
+        # Define patterns
+        people_words = ['person', 'people', 'man', 'woman', 'worker', 'workers', 'individual', 'human']
+        number_words = {
+            'one': 1, 'two': 2, 'three': 3, 'four': 4, 'five': 5,
+            'a': 1, 'single': 1, 'couple': 2, 'few': 3, 'several': 4, 'many': 5
+        }
+        track_words = ['track', 'tracks', 'rail', 'rails', 'railway', 'railroad']
+
+        # Extract numbers
+        explicit_numbers = re.findall(r'\b(\d+)\b', text)
+        explicit_numbers = [int(n) for n in explicit_numbers if 1 <= int(n) <= 20]
+
+        # Count mentions
+        people_mentions = sum(1 for word in people_words if word in text)
+        track_mentions = sum(1 for word in track_words if word in text)
+
+        # Find number words
+        found_numbers = [num for word, num in number_words.items() if word in text]
+
+        # Determine count
+        estimated_count = 0
+        if explicit_numbers:
+            estimated_count = explicit_numbers[0]
+        elif found_numbers:
+            estimated_count = max(found_numbers)
+        elif people_mentions > 0:
+            estimated_count = people_mentions
+
+        # Build response
+        if estimated_count > 0:
+            if track_mentions > 0:
+                return f"Detected approximately {estimated_count} person{'s' if estimated_count > 1 else ''} in railway scene. Scene: {text[:100]}..."
+            else:
+                return f"Detected approximately {estimated_count} person{'s' if estimated_count > 1 else ''} in image. Scene: {text[:100]}..."
+        else:
+            return f"No clear person count detected. Scene description: {text[:150]}..."
 
 
 class TransformerImageCaptioner:
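
Note on the heuristic above: `_handle_counting_prompt` merges an unprompted caption with a people-focused one, then `_extract_count_from_text` scans the combined text with plain substring tests. Substring matching means table entries like `'a': 1` hit almost every caption ("a man", "many"), so `found_numbers` is nearly always non-empty. Below is a minimal standalone sketch of the same extraction step using whole-token matching instead; the function name `estimate_person_count` and the sample captions are illustrative, not part of this commit:

import re

# Standalone sketch of the counting heuristic (illustration only).
# Whole-token matching avoids the substring pitfall described above.
PEOPLE_WORDS = {'person', 'people', 'man', 'woman', 'worker', 'workers', 'individual', 'human'}
NUMBER_WORDS = {'one': 1, 'two': 2, 'three': 3, 'four': 4, 'five': 5,
                'single': 1, 'couple': 2, 'few': 3, 'several': 4, 'many': 5}

def estimate_person_count(text: str) -> int:
    tokens = re.findall(r'[a-z]+|\d+', text.lower())
    # Explicit digits in a plausible range take priority
    digits = [int(t) for t in tokens if t.isdigit() and 1 <= int(t) <= 20]
    if digits:
        return digits[0]
    # Then spelled-out quantity words
    spelled = [NUMBER_WORDS[t] for t in tokens if t in NUMBER_WORDS]
    if spelled:
        return max(spelled)
    # Fall back to the number of distinct people keywords mentioned
    return sum(1 for t in set(tokens) if t in PEOPLE_WORDS)

print(estimate_person_count("3 people walking along the railway"))  # 3
print(estimate_person_count("two workers on the tracks"))           # 2
print(estimate_person_count("a man and a woman near a train"))      # 2 (keyword fallback)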
ui_components.py CHANGED
@@ -168,11 +168,15 @@ def render_frame_result(result_data: Dict[str, Any]):
     Render a single frame result with ontology analysis
     """
     ontology = result_data['ontology_analysis']
-    severity_icon = ontology.get('severity_icon', '✅')
-    severity = ontology.get('severity', 'NONE')
 
-    # Create expander title with severity indicator
-    expander_title = f"{severity_icon} {severity} - Frame {result_data['frame_number']} (t={result_data['timestamp']:.1f}s)"
+    # Create expander title - only include severity if ontology is active
+    if ontology.get('ontology_used', False):
+        severity_icon = ontology.get('severity_icon', '✅')
+        severity = ontology.get('severity', 'NONE')
+        expander_title = f"{severity_icon} {severity} - Frame {result_data['frame_number']} (t={result_data['timestamp']:.1f}s)"
+    else:
+        # Clean title without severity symbols when ontology is disabled
+        expander_title = f"Frame {result_data['frame_number']} (t={result_data['timestamp']:.1f}s)"
 
     with st.expander(expander_title):
        col_img, col_text = st.columns([1, 2])
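
For reference, the new title logic gates the severity prefix on `ontology.get('ontology_used')`. A quick sketch of both branches with made-up `result_data` values (the helper name and sample dicts are not from the repo):

def make_expander_title(result_data: dict) -> str:
    # Mirrors the committed logic, pulled out of the Streamlit renderer for illustration
    ontology = result_data['ontology_analysis']
    if ontology.get('ontology_used', False):
        icon = ontology.get('severity_icon', '✅')
        severity = ontology.get('severity', 'NONE')
        return f"{icon} {severity} - Frame {result_data['frame_number']} (t={result_data['timestamp']:.1f}s)"
    return f"Frame {result_data['frame_number']} (t={result_data['timestamp']:.1f}s)"

frame = {'frame_number': 12, 'timestamp': 4.0}
print(make_expander_title({**frame, 'ontology_analysis': {'ontology_used': True, 'severity': 'HIGH', 'severity_icon': '⚠️'}}))
# ⚠️ HIGH - Frame 12 (t=4.0s)
print(make_expander_title({**frame, 'ontology_analysis': {'ontology_used': False}}))
# Frame 12 (t=4.0s)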