EYEDOL commited on
Commit
30a56dc
·
verified ·
1 Parent(s): 04ef5f9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -0
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import json
4
+ import numpy as np
5
+ from PIL import Image
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from transformers import AutoProcessor, AutoModel
9
+ import faiss
10
+ import gradio as gr
11
+
12
+ # CONFIG - make sure paths match those produced by build_index.py
13
+ MODEL_ID = "EYEDOL/siglipFULL-agri-finetuned"
14
+ FAISS_DIR = "faiss_data"
15
+ INDEX_FILE = os.path.join(FAISS_DIR, "texts.faiss")
16
+ TEXTS_JSONL = os.path.join(FAISS_DIR, "texts.jsonl")
17
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
18
+ TOP_K = 5
19
+
20
+ # Load metadata texts into memory
21
+ texts = []
22
+ with open(TEXTS_JSONL, "r", encoding="utf-8") as f:
23
+ for line in f:
24
+ obj = json.loads(line.strip())
25
+ texts.append(obj.get("text", ""))
26
+
27
+ print(f"Loaded {len(texts)} texts.")
28
+
29
+ # Load FAISS index
30
+ print("Loading FAISS index...")
31
+ index = faiss.read_index(INDEX_FILE) # IndexFlatIP saved previously
32
+ # If index is on CPU but you want to use GPU inference in Space, you can move to GPU if available and faiss-gpu installed.
33
+
34
+ # Load model + processor
35
+ print("Loading model & processor...")
36
+ processor = AutoProcessor.from_pretrained(MODEL_ID)
37
+ model = AutoModel.from_pretrained(MODEL_ID).to(DEVICE)
38
+ model.eval()
39
+
40
+ def search_image(image: Image.Image, top_k: int = TOP_K):
41
+ # Preprocess image
42
+ inputs = processor(images=image.convert("RGB"), return_tensors="pt").to(DEVICE)
43
+ with torch.no_grad():
44
+ img_embed = model.get_image_features(**inputs) # (1, D)
45
+ img_embed = img_embed / img_embed.norm(p=2, dim=-1, keepdim=True)
46
+ img_vec = img_embed.cpu().numpy().astype('float32') # shape (1, D)
47
+
48
+ # Query FAISS (index expects float32)
49
+ D, I = index.search(img_vec, top_k) # D=distance matrix (inner product), I=indices
50
+ results = []
51
+ for score, idx in zip(D[0], I[0]):
52
+ if idx < 0:
53
+ continue
54
+ text = texts[idx] if idx < len(texts) else ""
55
+ # score is inner product cosine since vectors were normalized (range -1..1)
56
+ results.append({"text": text, "score": float(score)})
57
+ return results
58
+
59
+ # Build Gradio UI
60
+ def infer_and_format(file, top_k):
61
+ if file is None:
62
+ return "Upload an image", None
63
+ image = Image.open(file).convert("RGB")
64
+ results = search_image(image, top_k)
65
+ # build HTML or simple text output
66
+ lines = []
67
+ for i, r in enumerate(results, 1):
68
+ lines.append(f"<b>Rank {i}</b> — score: {r['score']:.4f}<br>{r['text']}")
69
+ html = "<br><br>".join(lines)
70
+ return html, image
71
+
72
+ with gr.Blocks() as demo:
73
+ gr.Markdown("# Image → Retrieved Texts")
74
+ with gr.Row():
75
+ with gr.Column(scale=1):
76
+ img_in = gr.Image(type="filepath", label="Upload image")
77
+ k_slider = gr.Slider(1, 10, value=TOP_K, step=1, label="Top K")
78
+ run_btn = gr.Button("Retrieve")
79
+ with gr.Column(scale=1):
80
+ out_html = gr.HTML()
81
+ out_img = gr.Image(label="Input image (preview)")
82
+
83
+ run_btn.click(infer_and_format, inputs=[img_in, k_slider], outputs=[out_html, out_img])
84
+
85
+ if __name__ == "__main__":
86
+ demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))