IT4CHI2311 committed on
Commit
09421ab
·
1 Parent(s): e3df3aa

Made changes

Browse files
Files changed (1) hide show
  1. app.py +44 -13
app.py CHANGED
@@ -106,30 +106,61 @@ def extract_features(image, rcnn_backbone, llava_model, llava_processor):
106
  rcnn_features['pool'], (1, 1)
107
  ).flatten().cpu().numpy()
108
 
109
- # LLaVA Phi-3-Mini features (lightweight vision-language model)
110
  if USE_LLAVA and llava_model is not None:
111
- prompt = "USER: <image>\nDescribe this image in detail.\nASSISTANT:"
 
 
 
 
 
112
  inputs = llava_processor(text=prompt, images=image, return_tensors="pt")
113
  inputs = {k: v.to(device) for k, v in inputs.items()}
114
 
115
- outputs = llava_model.generate(
116
- **inputs,
117
- max_new_tokens=77,
118
- output_hidden_states=True,
119
- return_dict_in_generate=True
120
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
- # Extract LLaVA features from last hidden state
123
- hidden_states = outputs.hidden_states[0][-1]
124
- llava_feat_vector = hidden_states.mean(dim=1).squeeze().cpu().numpy()
125
 
126
  # Resize to 1024 dimensions
127
  if llava_feat_vector.shape[0] != 1024:
128
  if llava_feat_vector.shape[0] < 1024:
129
- # Pad if smaller
130
  llava_feat_vector = np.pad(llava_feat_vector, (0, 1024 - llava_feat_vector.shape[0]))
131
  else:
132
- # Truncate if larger
133
  llava_feat_vector = llava_feat_vector[:1024]
134
  else:
135
  # Use zeros when LLaVA is disabled (maintains compatibility)
 
106
  rcnn_features['pool'], (1, 1)
107
  ).flatten().cpu().numpy()
108
 
109
+ # LLaVA Phi-3-Mini features (FAST - direct vision encoder, no text generation)
110
  if USE_LLAVA and llava_model is not None:
111
+ # CRITICAL: Ensure patch_size is set before processing
112
+ if hasattr(llava_processor, 'image_processor'):
113
+ llava_processor.image_processor.patch_size = 14
114
+ llava_processor.patch_size = 14
115
+
116
+ prompt = "USER: <image>\nASSISTANT:"
117
  inputs = llava_processor(text=prompt, images=image, return_tensors="pt")
118
  inputs = {k: v.to(device) for k, v in inputs.items()}
119
 
120
+ # Extract visual features directly (10-20x faster than generate())
121
+ # Get vision tower
122
+ if hasattr(llava_model, 'get_vision_tower'):
123
+ vision_tower = llava_model.get_vision_tower()
124
+ elif hasattr(llava_model, 'vision_tower'):
125
+ vision_tower = llava_model.vision_tower
126
+ else:
127
+ vision_tower = None
128
+
129
+ # Use vision tower directly if available
130
+ if vision_tower is not None and 'pixel_values' in inputs:
131
+ image_outputs = vision_tower(inputs['pixel_values'])
132
+
133
+ # Handle different output types
134
+ if hasattr(image_outputs, 'pooler_output'):
135
+ llava_feat_vector = image_outputs.pooler_output.squeeze().cpu().numpy()
136
+ elif hasattr(image_outputs, 'last_hidden_state'):
137
+ llava_feat_vector = image_outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
138
+ elif isinstance(image_outputs, tuple):
139
+ llava_feat_vector = image_outputs[0].mean(dim=1).squeeze().cpu().numpy()
140
+ else:
141
+ if image_outputs.dim() > 2:
142
+ llava_feat_vector = image_outputs.mean(dim=1).squeeze().cpu().numpy()
143
+ else:
144
+ llava_feat_vector = image_outputs.squeeze().cpu().numpy()
145
+ else:
146
+ # Fallback: use model forward pass (still much faster than generate)
147
+ outputs = llava_model(
148
+ input_ids=inputs['input_ids'],
149
+ attention_mask=inputs.get('attention_mask'),
150
+ pixel_values=inputs.get('pixel_values'),
151
+ output_hidden_states=True
152
+ )
153
+ llava_feat_vector = outputs.hidden_states[-1].mean(dim=1).squeeze().cpu().numpy()
154
 
155
+ # Ensure proper shape
156
+ if llava_feat_vector.ndim > 1:
157
+ llava_feat_vector = llava_feat_vector.flatten()
158
 
159
  # Resize to 1024 dimensions
160
  if llava_feat_vector.shape[0] != 1024:
161
  if llava_feat_vector.shape[0] < 1024:
 
162
  llava_feat_vector = np.pad(llava_feat_vector, (0, 1024 - llava_feat_vector.shape[0]))
163
  else:
 
164
  llava_feat_vector = llava_feat_vector[:1024]
165
  else:
166
  # Use zeros when LLaVA is disabled (maintains compatibility)