Spaces:
Sleeping
Sleeping
Commit ·
09421ab
1
Parent(s): e3df3aa
Made changes
Browse files
app.py
CHANGED
|
@@ -106,30 +106,61 @@ def extract_features(image, rcnn_backbone, llava_model, llava_processor):
|
|
| 106 |
rcnn_features['pool'], (1, 1)
|
| 107 |
).flatten().cpu().numpy()
|
| 108 |
|
| 109 |
-
# LLaVA Phi-3-Mini features (
|
| 110 |
if USE_LLAVA and llava_model is not None:
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
inputs = llava_processor(text=prompt, images=image, return_tensors="pt")
|
| 113 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 114 |
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
-
#
|
| 123 |
-
|
| 124 |
-
|
| 125 |
|
| 126 |
# Resize to 1024 dimensions
|
| 127 |
if llava_feat_vector.shape[0] != 1024:
|
| 128 |
if llava_feat_vector.shape[0] < 1024:
|
| 129 |
-
# Pad if smaller
|
| 130 |
llava_feat_vector = np.pad(llava_feat_vector, (0, 1024 - llava_feat_vector.shape[0]))
|
| 131 |
else:
|
| 132 |
-
# Truncate if larger
|
| 133 |
llava_feat_vector = llava_feat_vector[:1024]
|
| 134 |
else:
|
| 135 |
# Use zeros when LLaVA is disabled (maintains compatibility)
|
|
|
|
| 106 |
rcnn_features['pool'], (1, 1)
|
| 107 |
).flatten().cpu().numpy()
|
| 108 |
|
| 109 |
+
# LLaVA Phi-3-Mini features (FAST - direct vision encoder, no text generation)
|
| 110 |
if USE_LLAVA and llava_model is not None:
|
| 111 |
+
# CRITICAL: Ensure patch_size is set before processing
|
| 112 |
+
if hasattr(llava_processor, 'image_processor'):
|
| 113 |
+
llava_processor.image_processor.patch_size = 14
|
| 114 |
+
llava_processor.patch_size = 14
|
| 115 |
+
|
| 116 |
+
prompt = "USER: <image>\nASSISTANT:"
|
| 117 |
inputs = llava_processor(text=prompt, images=image, return_tensors="pt")
|
| 118 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 119 |
|
| 120 |
+
# Extract visual features directly (10-20x faster than generate())
|
| 121 |
+
# Get vision tower
|
| 122 |
+
if hasattr(llava_model, 'get_vision_tower'):
|
| 123 |
+
vision_tower = llava_model.get_vision_tower()
|
| 124 |
+
elif hasattr(llava_model, 'vision_tower'):
|
| 125 |
+
vision_tower = llava_model.vision_tower
|
| 126 |
+
else:
|
| 127 |
+
vision_tower = None
|
| 128 |
+
|
| 129 |
+
# Use vision tower directly if available
|
| 130 |
+
if vision_tower is not None and 'pixel_values' in inputs:
|
| 131 |
+
image_outputs = vision_tower(inputs['pixel_values'])
|
| 132 |
+
|
| 133 |
+
# Handle different output types
|
| 134 |
+
if hasattr(image_outputs, 'pooler_output'):
|
| 135 |
+
llava_feat_vector = image_outputs.pooler_output.squeeze().cpu().numpy()
|
| 136 |
+
elif hasattr(image_outputs, 'last_hidden_state'):
|
| 137 |
+
llava_feat_vector = image_outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
|
| 138 |
+
elif isinstance(image_outputs, tuple):
|
| 139 |
+
llava_feat_vector = image_outputs[0].mean(dim=1).squeeze().cpu().numpy()
|
| 140 |
+
else:
|
| 141 |
+
if image_outputs.dim() > 2:
|
| 142 |
+
llava_feat_vector = image_outputs.mean(dim=1).squeeze().cpu().numpy()
|
| 143 |
+
else:
|
| 144 |
+
llava_feat_vector = image_outputs.squeeze().cpu().numpy()
|
| 145 |
+
else:
|
| 146 |
+
# Fallback: use model forward pass (still much faster than generate)
|
| 147 |
+
outputs = llava_model(
|
| 148 |
+
input_ids=inputs['input_ids'],
|
| 149 |
+
attention_mask=inputs.get('attention_mask'),
|
| 150 |
+
pixel_values=inputs.get('pixel_values'),
|
| 151 |
+
output_hidden_states=True
|
| 152 |
+
)
|
| 153 |
+
llava_feat_vector = outputs.hidden_states[-1].mean(dim=1).squeeze().cpu().numpy()
|
| 154 |
|
| 155 |
+
# Ensure proper shape
|
| 156 |
+
if llava_feat_vector.ndim > 1:
|
| 157 |
+
llava_feat_vector = llava_feat_vector.flatten()
|
| 158 |
|
| 159 |
# Resize to 1024 dimensions
|
| 160 |
if llava_feat_vector.shape[0] != 1024:
|
| 161 |
if llava_feat_vector.shape[0] < 1024:
|
|
|
|
| 162 |
llava_feat_vector = np.pad(llava_feat_vector, (0, 1024 - llava_feat_vector.shape[0]))
|
| 163 |
else:
|
|
|
|
| 164 |
llava_feat_vector = llava_feat_vector[:1024]
|
| 165 |
else:
|
| 166 |
# Use zeros when LLaVA is disabled (maintains compatibility)
|