Update app.py
Browse files
app.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
import sys
|
| 2 |
sys.stdout.reconfigure(line_buffering=True)
|
| 3 |
|
| 4 |
-
|
| 5 |
import os
|
| 6 |
import numpy as np
|
| 7 |
import requests
|
|
@@ -39,7 +38,8 @@ device = torch.device('cuda' if 'cuda' in device_name and torch.cuda.is_availabl
|
|
| 39 |
print(f"🧠 Using device: {device}")
|
| 40 |
|
| 41 |
print("...Loading Grounding DINO model...")
|
| 42 |
-
|
|
|
|
| 43 |
processor_gnd = AutoProcessor.from_pretrained(gnd_model_id)
|
| 44 |
model_gnd = AutoModelForZeroShotObjectDetection.from_pretrained(gnd_model_id).to(device)
|
| 45 |
|
|
@@ -49,7 +49,8 @@ sam_model = sam_model_registry["vit_b"](checkpoint=sam_checkpoint).to(device)
|
|
| 49 |
predictor = SamPredictor(sam_model)
|
| 50 |
|
| 51 |
print("...Loading BGE model for text embeddings...")
|
| 52 |
-
|
|
|
|
| 53 |
tokenizer_text = AutoTokenizer.from_pretrained(bge_model_id)
|
| 54 |
model_text = AutoModel.from_pretrained(bge_model_id).to(device)
|
| 55 |
print("✅ All models loaded successfully.")
|
|
@@ -236,14 +237,14 @@ def compare_items():
|
|
| 236 |
# Text comparison is always done
|
| 237 |
text_emb_found = np.array(item['text_embedding'])
|
| 238 |
text_score = cosine_similarity(query_text_emb, text_emb_found)
|
| 239 |
-
print(f"
|
| 240 |
|
| 241 |
# --- NEW: Check if BOTH items have visual features ---
|
| 242 |
has_query_image = 'shape_features' in query_item and query_item['shape_features']
|
| 243 |
has_item_image = 'shape_features' in item and item['shape_features']
|
| 244 |
|
| 245 |
if has_query_image and has_item_image:
|
| 246 |
-
print("
|
| 247 |
# If both have images, proceed with full comparison
|
| 248 |
query_shape_feat = np.array(query_item['shape_features'])
|
| 249 |
query_color_feat = np.array(query_item['color_features']).astype("float32")
|
|
@@ -259,25 +260,25 @@ def compare_items():
|
|
| 259 |
texture_score = cv2.compareHist(query_texture_feat, found_texture, cv2.HISTCMP_CORREL)
|
| 260 |
|
| 261 |
raw_image_score = (FEATURE_WEIGHTS["shape"] * shape_score +
|
| 262 |
-
|
| 263 |
-
|
| 264 |
|
| 265 |
-
print(f"
|
| 266 |
|
| 267 |
image_score = stretch_image_score(raw_image_score)
|
| 268 |
|
| 269 |
# Weighted average of image and text scores
|
| 270 |
final_score = 0.4 * image_score + 0.6 * text_score
|
| 271 |
-
print(f"
|
| 272 |
|
| 273 |
else:
|
| 274 |
# If one or both items lack an image, the final score is JUST the text score
|
| 275 |
-
print("
|
| 276 |
final_score = text_score
|
| 277 |
|
| 278 |
# Check if the final score meets the threshold
|
| 279 |
if final_score >= FINAL_SCORE_THRESHOLD:
|
| 280 |
-
print(f"
|
| 281 |
results.append({
|
| 282 |
"_id": item_id,
|
| 283 |
"score": round(final_score, 4),
|
|
@@ -286,7 +287,7 @@ def compare_items():
|
|
| 286 |
"objectImage": item.get("objectImage"),
|
| 287 |
})
|
| 288 |
else:
|
| 289 |
-
print(f"
|
| 290 |
|
| 291 |
except Exception as e:
|
| 292 |
print(f" [Skipping] Item {item_id} due to processing error: {e}")
|
|
|
|
| 1 |
import sys
|
| 2 |
sys.stdout.reconfigure(line_buffering=True)
|
| 3 |
|
|
|
|
| 4 |
import os
|
| 5 |
import numpy as np
|
| 6 |
import requests
|
|
|
|
| 38 |
print(f"🧠 Using device: {device}")
|
| 39 |
|
| 40 |
print("...Loading Grounding DINO model...")
|
| 41 |
+
# --- ⬇️ UPGRADED MODEL ⬇️ ---
|
| 42 |
+
gnd_model_id = "IDEA-Research/grounding-dino-large"
|
| 43 |
processor_gnd = AutoProcessor.from_pretrained(gnd_model_id)
|
| 44 |
model_gnd = AutoModelForZeroShotObjectDetection.from_pretrained(gnd_model_id).to(device)
|
| 45 |
|
|
|
|
| 49 |
predictor = SamPredictor(sam_model)
|
| 50 |
|
| 51 |
print("...Loading BGE model for text embeddings...")
|
| 52 |
+
# --- ⬇️ UPGRADED MODEL ⬇️ ---
|
| 53 |
+
bge_model_id = "BAAI/bge-large-en-v1.5"
|
| 54 |
tokenizer_text = AutoTokenizer.from_pretrained(bge_model_id)
|
| 55 |
model_text = AutoModel.from_pretrained(bge_model_id).to(device)
|
| 56 |
print("✅ All models loaded successfully.")
|
|
|
|
| 237 |
# Text comparison is always done
|
| 238 |
text_emb_found = np.array(item['text_embedding'])
|
| 239 |
text_score = cosine_similarity(query_text_emb, text_emb_found)
|
| 240 |
+
print(f" - Text Score: {text_score:.4f}")
|
| 241 |
|
| 242 |
# --- NEW: Check if BOTH items have visual features ---
|
| 243 |
has_query_image = 'shape_features' in query_item and query_item['shape_features']
|
| 244 |
has_item_image = 'shape_features' in item and item['shape_features']
|
| 245 |
|
| 246 |
if has_query_image and has_item_image:
|
| 247 |
+
print(" - Both items have images. Performing visual comparison.")
|
| 248 |
# If both have images, proceed with full comparison
|
| 249 |
query_shape_feat = np.array(query_item['shape_features'])
|
| 250 |
query_color_feat = np.array(query_item['color_features']).astype("float32")
|
|
|
|
| 260 |
texture_score = cv2.compareHist(query_texture_feat, found_texture, cv2.HISTCMP_CORREL)
|
| 261 |
|
| 262 |
raw_image_score = (FEATURE_WEIGHTS["shape"] * shape_score +
|
| 263 |
+
FEATURE_WEIGHTS["color"] * color_score +
|
| 264 |
+
FEATURE_WEIGHTS["texture"] * texture_score)
|
| 265 |
|
| 266 |
+
print(f" - Raw Image Score: {raw_image_score:.4f}")
|
| 267 |
|
| 268 |
image_score = stretch_image_score(raw_image_score)
|
| 269 |
|
| 270 |
# Weighted average of image and text scores
|
| 271 |
final_score = 0.4 * image_score + 0.6 * text_score
|
| 272 |
+
print(f" - Image Score: {image_score:.4f} | Final Score: {final_score:.4f}")
|
| 273 |
|
| 274 |
else:
|
| 275 |
# If one or both items lack an image, the final score is JUST the text score
|
| 276 |
+
print(" - One or both items missing image. Using text score only.")
|
| 277 |
final_score = text_score
|
| 278 |
|
| 279 |
# Check if the final score meets the threshold
|
| 280 |
if final_score >= FINAL_SCORE_THRESHOLD:
|
| 281 |
+
print(f" - ✅ ACCEPTED (Score >= {FINAL_SCORE_THRESHOLD})")
|
| 282 |
results.append({
|
| 283 |
"_id": item_id,
|
| 284 |
"score": round(final_score, 4),
|
|
|
|
| 287 |
"objectImage": item.get("objectImage"),
|
| 288 |
})
|
| 289 |
else:
|
| 290 |
+
print(f" - ❌ REJECTED (Score < {FINAL_SCORE_THRESHOLD})")
|
| 291 |
|
| 292 |
except Exception as e:
|
| 293 |
print(f" [Skipping] Item {item_id} due to processing error: {e}")
|