EYEDOL committed (verified)
Commit 8576118 · 1 Parent(s): 73ef8da

Update app.py

Files changed (1)
  1. app.py +38 -39
app.py CHANGED
@@ -4,83 +4,82 @@ import json
  import numpy as np
  from PIL import Image
  import torch
- import torch.nn.functional as F
  from transformers import AutoProcessor, AutoModel
- import faiss
  import gradio as gr

- # CONFIG - make sure paths match those produced by build_index.py
  MODEL_ID = "EYEDOL/siglipFULL-agri-finetuned"
- FAISS_DIR = "faiss_data"
- INDEX_FILE = os.path.join(FAISS_DIR, "texts.faiss")
- TEXTS_JSONL = os.path.join(FAISS_DIR, "texts.jsonl")
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- TOP_K = 5

- # Load metadata texts into memory
  texts = []
- with open(TEXTS_JSONL, "r", encoding="utf-8") as f:
      for line in f:
          obj = json.loads(line.strip())
          texts.append(obj.get("text", ""))

- print(f"Loaded {len(texts)} texts.")

- # Load FAISS index
- print("Loading FAISS index...")
- index = faiss.read_index(INDEX_FILE)  # IndexFlatIP saved previously
- # If index is on CPU but you want to use GPU inference in Space, you can move to GPU if available and faiss-gpu installed.

- # Load model + processor
- print("Loading model & processor...")
  processor = AutoProcessor.from_pretrained(MODEL_ID)
  model = AutoModel.from_pretrained(MODEL_ID).to(DEVICE)
  model.eval()

- def search_image(image: Image.Image, top_k: int = TOP_K):
-     # Preprocess image
      inputs = processor(images=image.convert("RGB"), return_tensors="pt").to(DEVICE)
      with torch.no_grad():
          img_embed = model.get_image_features(**inputs)  # (1, D)
          img_embed = img_embed / img_embed.norm(p=2, dim=-1, keepdim=True)
-     img_vec = img_embed.cpu().numpy().astype('float32')  # shape (1, D)

-     # Query FAISS (index expects float32)
-     D, I = index.search(img_vec, top_k)  # D=distance matrix (inner product), I=indices
      results = []
-     for score, idx in zip(D[0], I[0]):
-         if idx < 0:
-             continue
          text = texts[idx] if idx < len(texts) else ""
-         # score is inner product cosine since vectors were normalized (range -1..1)
-         results.append({"text": text, "score": float(score)})
-     return results

- # Build Gradio UI
- def infer_and_format(file, top_k):
-     if file is None:
-         return "Upload an image", None
-     image = Image.open(file).convert("RGB")
-     results = search_image(image, top_k)
-     # build HTML or simple text output
      lines = []
-     for i, r in enumerate(results, 1):
          lines.append(f"<b>Rank {i}</b> — score: {r['score']:.4f}<br>{r['text']}")
      html = "<br><br>".join(lines)
      return html, image

  with gr.Blocks() as demo:
-     gr.Markdown("# Image → Retrieved Texts")
      with gr.Row():
          with gr.Column(scale=1):
-             img_in = gr.Image(type="filepath", label="Upload image")
-             k_slider = gr.Slider(1, 10, value=TOP_K, step=1, label="Top K")
              run_btn = gr.Button("Retrieve")
          with gr.Column(scale=1):
              out_html = gr.HTML()
              out_img = gr.Image(label="Input image (preview)")

-     run_btn.click(infer_and_format, inputs=[img_in, k_slider], outputs=[out_html, out_img])

  if __name__ == "__main__":
      demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
 
  import numpy as np
  from PIL import Image
  import torch
  from transformers import AutoProcessor, AutoModel
+ from sklearn.neighbors import NearestNeighbors
  import gradio as gr

+ # CONFIG
  MODEL_ID = "EYEDOL/siglipFULL-agri-finetuned"
+ DATA_DIR = "faiss_free_data"
+ EMBEDS_FILE = os.path.join(DATA_DIR, "text_embeds.npy")
+ TEXTS_FILE = os.path.join(DATA_DIR, "texts.jsonl")
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ DEFAULT_TOPK = 5

+ # ---- Load texts metadata
  texts = []
+ with open(TEXTS_FILE, "r", encoding="utf-8") as f:
      for line in f:
          obj = json.loads(line.strip())
          texts.append(obj.get("text", ""))

+ # ---- Load embeddings
+ print("Loading embeddings...")
+ embs = np.load(EMBEDS_FILE)  # shape (N, D), dtype float32
+ print("Embeddings loaded:", embs.shape)

+ # ---- Build (or load) NearestNeighbors index
+ # We use metric='cosine' so kneighbors returns cosine *distance* (range 0..2)
+ # We'll convert to similarity: sim = 1 - distance (works when embeddings were normalized)
+ nn = NearestNeighbors(n_neighbors=DEFAULT_TOPK, metric="cosine", n_jobs=-1)
+ nn.fit(embs)
+ print("NearestNeighbors ready.")

+ # ---- Load model & processor
  processor = AutoProcessor.from_pretrained(MODEL_ID)
  model = AutoModel.from_pretrained(MODEL_ID).to(DEVICE)
  model.eval()

+ def retrieve_texts_from_image(image: Image.Image, top_k: int = DEFAULT_TOPK):
+     if image is None:
+         return "No image uploaded", None
+
+     # Compute image embedding
      inputs = processor(images=image.convert("RGB"), return_tensors="pt").to(DEVICE)
      with torch.no_grad():
          img_embed = model.get_image_features(**inputs)  # (1, D)
          img_embed = img_embed / img_embed.norm(p=2, dim=-1, keepdim=True)
+     img_vec = img_embed.cpu().numpy().astype("float32")  # (1, D)

+     # Query NN
+     distances, indices = nn.kneighbors(img_vec, n_neighbors=top_k)
+     # sklearn returns cosine distances: dist = 1 - cosine_similarity (if vectors normalized)
      results = []
+     for dist, idx in zip(distances[0], indices[0]):
+         sim = 1.0 - float(dist)  # similarity score in approx range [-1..1], typically [0..1]
          text = texts[idx] if idx < len(texts) else ""
+         results.append({"text": text, "score": sim, "id": int(idx)})

+     # format HTML
      lines = []
+     for i, r in enumerate(results, start=1):
          lines.append(f"<b>Rank {i}</b> — score: {r['score']:.4f}<br>{r['text']}")
      html = "<br><br>".join(lines)
      return html, image

+ # ---- Gradio UI
  with gr.Blocks() as demo:
+     gr.Markdown("# Image → Retrieved Texts (NO FAISS)")
      with gr.Row():
          with gr.Column(scale=1):
+             img_in = gr.Image(type="pil", label="Upload image")
+             k_slider = gr.Slider(1, 20, value=DEFAULT_TOPK, step=1, label="Top K")
              run_btn = gr.Button("Retrieve")
          with gr.Column(scale=1):
              out_html = gr.HTML()
              out_img = gr.Image(label="Input image (preview)")

+     run_btn.click(retrieve_texts_from_image, inputs=[img_in, k_slider], outputs=[out_html, out_img])

  if __name__ == "__main__":
      demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
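
The swap from faiss.IndexFlatIP to sklearn's NearestNeighbors(metric="cosine") preserves the ranking as long as the stored text embeddings are L2-normalized: cosine distance on unit vectors is 1 minus their inner product, so the new "sim = 1 - dist" score reproduces the inner-product score the FAISS index returned. A minimal sketch of that check, with random placeholder vectors standing in for the real SigLIP embeddings:

import numpy as np
from sklearn.neighbors import NearestNeighbors

# Random placeholder vectors standing in for the real text/image embeddings.
rng = np.random.default_rng(0)
corpus = rng.normal(size=(100, 16)).astype("float32")
corpus /= np.linalg.norm(corpus, axis=1, keepdims=True)   # L2-normalized, as app.py assumes

query = rng.normal(size=(1, 16)).astype("float32")
query /= np.linalg.norm(query, axis=1, keepdims=True)

nn = NearestNeighbors(metric="cosine").fit(corpus)
dist, idx = nn.kneighbors(query, n_neighbors=5)
sim = 1.0 - dist[0]                        # what the new app reports as "score"

inner_product = corpus[idx[0]] @ query[0]  # what the old IndexFlatIP search would have returned
assert np.allclose(sim, inner_product, atol=1e-4)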
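
The app still assumes the retrieval artifacts exist on disk; this commit only changes how they are searched. The build step that produces them is not shown, so the following is only a sketch under assumptions: the corpus texts (placeholders below) are encoded with the same checkpoint's get_text_features, L2-normalized to match the image side, and written as faiss_free_data/text_embeds.npy plus faiss_free_data/texts.jsonl in the format app.py reads.

import json
import os

import numpy as np
import torch
from transformers import AutoProcessor, AutoModel

MODEL_ID = "EYEDOL/siglipFULL-agri-finetuned"
DATA_DIR = "faiss_free_data"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Placeholder corpus; the real text source is not part of this commit.
corpus = ["maize leaf with gray leaf spot", "healthy cassava plant"]

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModel.from_pretrained(MODEL_ID).to(DEVICE).eval()

embeds = []
with torch.no_grad():
    for text in corpus:
        # padding="max_length" is the usual SigLIP setting; whether the original
        # build script used it is an assumption.
        inputs = processor(text=text, padding="max_length", truncation=True, return_tensors="pt").to(DEVICE)
        emb = model.get_text_features(**inputs)          # (1, D)
        emb = emb / emb.norm(p=2, dim=-1, keepdim=True)  # normalize, like the image side
        embeds.append(emb.cpu().numpy()[0])

os.makedirs(DATA_DIR, exist_ok=True)
np.save(os.path.join(DATA_DIR, "text_embeds.npy"), np.asarray(embeds, dtype="float32"))
with open(os.path.join(DATA_DIR, "texts.jsonl"), "w", encoding="utf-8") as f:
    for text in corpus:
        f.write(json.dumps({"text": text}, ensure_ascii=False) + "\n")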
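
Because the input component is now gr.Image(type="pil"), the handler receives a PIL.Image directly (the old version got a file path and opened it itself), so it can also be smoke-tested without the UI. "sample.jpg" below is a placeholder path:

from PIL import Image

# Placeholder image path; any local RGB image works for a quick check.
html, preview = retrieve_texts_from_image(Image.open("sample.jpg"), top_k=3)
print(html)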