samwaugh committed on
Commit
f5bc4f1
·
1 Parent(s): 0ba12ad

Logging for new embeddings in inference.py

Browse files
Files changed (1) hide show
  1. backend/runner/inference.py +39 -2
backend/runner/inference.py CHANGED
@@ -385,8 +385,45 @@ def run_inference(
385
  print(f"🔍 Loading and preprocessing image: {image_path}")
386
  image = Image.open(image_path).convert("RGB")
387
  print(f"✅ Image loaded successfully, size: {image.size}")
388
-
389
- # Continue with the rest of the function...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
 
391
  except Exception as e:
392
  print(f"❌ Error in run_inference: {e}")
 
385
  print(f"🔍 Loading and preprocessing image: {image_path}")
386
  image = Image.open(image_path).convert("RGB")
387
  print(f"✅ Image loaded successfully, size: {image.size}")
388
+
389
+ # Compute image embedding
390
+ inputs = processor(images=image, return_tensors="pt")
391
+ inputs = {k: v.to(device) for k, v in inputs.items()}
392
+
393
+ with torch.no_grad():
394
+ image_features = model.get_image_features(**inputs)
395
+ image_embedding = F.normalize(image_features.squeeze(0), dim=-1)
396
+
397
+ # Normalize sentence embeddings and compute similarities
398
+ sentence_embeddings = F.normalize(filtered_embeddings.to(device), dim=-1)
399
+ similarities = torch.matmul(sentence_embeddings, image_embedding).cpu()
400
+
401
+ # Get top-K results
402
+ k = min(top_k, len(similarities))
403
+ top_scores, top_indices = torch.topk(similarities, k=k)
404
+
405
+ # Build results with full sentence metadata
406
+ results = []
407
+ for rank, (idx, score) in enumerate(zip(top_indices.tolist(), top_scores.tolist()), start=1):
408
+ sentence_id = filtered_sentence_ids[idx]
409
+ sentence_data = sentences_data.get(
410
+ sentence_id,
411
+ {"English Original": f"[Sentence data not found for {sentence_id}]", "Has PaintingCLIP Embedding": True},
412
+ ).copy()
413
+ work_id = sentence_id.split("_")[0]
414
+ sentence_data.setdefault("Work", work_id)
415
+ results.append({
416
+ "id": sentence_id,
417
+ "score": float(score),
418
+ "english_original": sentence_data.get("English Original", "N/A"),
419
+ "work": work_id,
420
+ "rank": rank,
421
+ })
422
+
423
+ print(f"🔍 run_inference returning {len(results)} results")
424
+ if results:
425
+ print(f"🔍 First result: {results[0]}")
426
+ return results
427
 
428
  except Exception as e:
429
  print(f"❌ Error in run_inference: {e}")