Spaces:
Sleeping
Sleeping
Added logic for directly loading the dataset from Hugging Face
Browse files
app.py
CHANGED
|
@@ -37,6 +37,28 @@ CHROMA_DIR = Path("chroma_db")
|
|
| 37 |
DEFAULT_TOPK = 10
|
| 38 |
MAX_TOPK = 60
|
| 39 |
# ──────────────────────────────────────────────────────────────────────────────
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
# ── Load once at startup ──────────────────────────────────────────────────────
|
| 42 |
print(f"\nStarting up on device: {DEVICE}")
|
|
@@ -95,6 +117,7 @@ def retrieve(query: str, model_choice: str, top_k: int = DEFAULT_TOPK):
|
|
| 95 |
|
| 96 |
with torch.inference_mode():
|
| 97 |
output = model.get_text_features(**inputs)
|
|
|
|
| 98 |
text_features = output.pooler_output if hasattr(output, "pooler_output") else output
|
| 99 |
|
| 100 |
text_features = torch.nn.functional.normalize(text_features, dim=-1)
|
|
@@ -110,10 +133,7 @@ def retrieve(query: str, model_choice: str, top_k: int = DEFAULT_TOPK):
|
|
| 110 |
output_images = []
|
| 111 |
if results["distances"] and len(results["distances"]) > 0:
|
| 112 |
for i , (meta, dist) in enumerate(zip(results["metadatas"][0], results["distances"][0])):
|
| 113 |
-
|
| 114 |
-
if not img_path.exists():
|
| 115 |
-
continue
|
| 116 |
-
img = Image.open(img_path).convert("RGB")
|
| 117 |
caption = f"#{i + 1}"
|
| 118 |
output_images.append((img,caption))
|
| 119 |
|
|
|
|
| 37 |
DEFAULT_TOPK = 10
|
| 38 |
MAX_TOPK = 60
|
| 39 |
# ──────────────────────────────────────────────────────────────────────────────
|
| 40 |
+
USE_LOCAL_IMAGES = IMAGES_DIR.exists()
|
| 41 |
+
|
| 42 |
+
if USE_LOCAL_IMAGES:
|
| 43 |
+
print(f"Image source: local disk ({IMAGES_DIR})\n")
|
| 44 |
+
dataset = None
|
| 45 |
+
else:
|
| 46 |
+
print("Image source: HuggingFace dataset (data/images/ not found locally)")
|
| 47 |
+
print("Loading Flickr8k …")
|
| 48 |
+
from datasets import load_dataset
|
| 49 |
+
dataset = load_dataset("jxie/flickr8k", split="train+validation+test")
|
| 50 |
+
print(f" Dataset ready: {len(dataset)} images.\n")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def load_image(meta: dict) -> Image.Image:
|
| 54 |
+
"""
|
| 55 |
+
Load an image from local disk or HuggingFace dataset depending on
|
| 56 |
+
what is available at runtime.
|
| 57 |
+
"""
|
| 58 |
+
if USE_LOCAL_IMAGES:
|
| 59 |
+
return Image.open(IMAGES_DIR / meta["filename"]).convert("RGB")
|
| 60 |
+
else:
|
| 61 |
+
return dataset[meta["dataset_index"]]["image"].convert("RGB")
|
| 62 |
|
| 63 |
# ── Load once at startup ──────────────────────────────────────────────────────
|
| 64 |
print(f"\nStarting up on device: {DEVICE}")
|
|
|
|
| 117 |
|
| 118 |
with torch.inference_mode():
|
| 119 |
output = model.get_text_features(**inputs)
|
| 120 |
+
# Gestisce output che potrebbero differire leggermente tra architetture
|
| 121 |
text_features = output.pooler_output if hasattr(output, "pooler_output") else output
|
| 122 |
|
| 123 |
text_features = torch.nn.functional.normalize(text_features, dim=-1)
|
|
|
|
| 133 |
output_images = []
|
| 134 |
if results["distances"] and len(results["distances"]) > 0:
|
| 135 |
for i , (meta, dist) in enumerate(zip(results["metadatas"][0], results["distances"][0])):
|
| 136 |
+
img = load_image(meta)
|
|
|
|
|
|
|
|
|
|
| 137 |
caption = f"#{i + 1}"
|
| 138 |
output_images.append((img,caption))
|
| 139 |
|