Andy-6 committed on
Commit
3062100
·
1 Parent(s): 6041dbd

Added logic for directly loading dataset for huggingface

Browse files
Files changed (1) hide show
  1. app.py +24 -4
app.py CHANGED
@@ -37,6 +37,28 @@ CHROMA_DIR = Path("chroma_db")
37
  DEFAULT_TOPK = 10
38
  MAX_TOPK = 60
39
  # ──────────────────────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  # ── Load once at startup ──────────────────────────────────────────────────────
42
  print(f"\nStarting up on device: {DEVICE}")
@@ -95,6 +117,7 @@ def retrieve(query: str, model_choice: str, top_k: int = DEFAULT_TOPK):
95
 
96
  with torch.inference_mode():
97
  output = model.get_text_features(**inputs)
 
98
  text_features = output.pooler_output if hasattr(output, "pooler_output") else output
99
 
100
  text_features = torch.nn.functional.normalize(text_features, dim=-1)
@@ -110,10 +133,7 @@ def retrieve(query: str, model_choice: str, top_k: int = DEFAULT_TOPK):
110
  output_images = []
111
  if results["distances"] and len(results["distances"]) > 0:
112
  for i , (meta, dist) in enumerate(zip(results["metadatas"][0], results["distances"][0])):
113
- img_path = IMAGES_DIR / meta["filename"]
114
- if not img_path.exists():
115
- continue
116
- img = Image.open(img_path).convert("RGB")
117
  caption = f"#{i + 1}"
118
  output_images.append((img,caption))
119
 
 
37
  DEFAULT_TOPK = 10
38
  MAX_TOPK = 60
39
  # ──────────────────────────────────────────────────────────────────────────────
40
# ── Image source selection ────────────────────────────────────────────────────
# Decide once at startup where images come from: local disk if data/images/
# exists, otherwise the Flickr8k dataset streamed from the HuggingFace hub.
USE_LOCAL_IMAGES = IMAGES_DIR.exists()
dataset = None  # only populated when falling back to the HF dataset

if USE_LOCAL_IMAGES:
    print(f"Image source: local disk ({IMAGES_DIR})\n")
else:
    print("Image source: HuggingFace dataset (data/images/ not found locally)")
    print("Loading Flickr8k …")
    # Lazy import: `datasets` is only required when no local copy is present.
    from datasets import load_dataset
    dataset = load_dataset("jxie/flickr8k", split="train+validation+test")
    print(f" Dataset ready: {len(dataset)} images.\n")
53
def load_image(meta: dict) -> Image.Image:
    """Return the RGB image described by *meta*.

    Reads from local disk when data/images/ exists at startup, otherwise
    falls back to the HuggingFace dataset loaded into ``dataset``.
    """
    if not USE_LOCAL_IMAGES:
        # meta["dataset_index"] maps the Chroma record back to its HF row.
        return dataset[meta["dataset_index"]]["image"].convert("RGB")
    return Image.open(IMAGES_DIR / meta["filename"]).convert("RGB")
62
 
63
  # ── Load once at startup ──────────────────────────────────────────────────────
64
  print(f"\nStarting up on device: {DEVICE}")
 
117
 
118
  with torch.inference_mode():
119
  output = model.get_text_features(**inputs)
120
+ # Handle outputs that may differ slightly across architectures
121
  text_features = output.pooler_output if hasattr(output, "pooler_output") else output
122
 
123
  text_features = torch.nn.functional.normalize(text_features, dim=-1)
 
133
  output_images = []
134
  if results["distances"] and len(results["distances"]) > 0:
135
  for i , (meta, dist) in enumerate(zip(results["metadatas"][0], results["distances"][0])):
136
+ img = load_image(meta)
 
 
 
137
  caption = f"#{i + 1}"
138
  output_images.append((img,caption))
139