EYEDOL committed on
Commit 43590b6 · verified · 1 Parent(s): f8c57ca

Update app.py

Files changed (1)
  1. app.py +69 -143
app.py CHANGED
@@ -1,203 +1,129 @@
  """
- Gradio Space app (app.py) — SigLip image -> text retrieval
-
- Place this file as `app.py` in your Hugging Face Space. Add the requirements listed below to `requirements.txt` in the Space.
-
- How it works
- - On startup it loads your concatenated datasets and the fine-tuned model `EYEDOL/siglipFULL-agri-finetuned`.
- - It precomputes (and caches) normalized text embeddings on CPU to save GPU memory.
- - The Gradio UI lets users upload an image, view it, and see the top-k matched text captions.
-
- Notes for Spaces
- - If your model or datasets are private, add a `HUGGINGFACE_TOKEN` secret in the Space settings and set `USE_HF_TOKEN = True` below.
- - If you select a GPU runtime for the Space, the app will use it if available.
  """

  import os
- import tempfile
  from functools import lru_cache
  from typing import List, Tuple

  import gradio as gr
  import torch
  import torch.nn.functional as F
- from datasets import concatenate_datasets, load_dataset
  from PIL import Image
- from transformers import AutoModel, AutoProcessor
  from tqdm import tqdm

  # -------------------------
  # Config
  # -------------------------
- MODEL_ID = "EYEDOL/siglipFULL-agri-finetuned"
- DATASET_TEMPLATE = "EYEDOL/AGRILLAVA-image-text{}"  # datasets 1..15
- NUM_DATASETS = 2
- BATCH_SIZE = 16  # for text encoding
- USE_HF_TOKEN = False  # set True if model/datasets are private and you will pass token via environment
  TOP_K_DEFAULT = 3

- # Look for HF token in environment (Spaces -> Settings -> Secrets: set HUGGINGFACE_TOKEN)
- HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", None)
- if HF_TOKEN:
-     USE_HF_TOKEN = True
-
- # -------------------------
  # Device
- # -------------------------
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  # -------------------------
- # Utility: load & preprocess datasets
  # -------------------------
  @lru_cache(maxsize=1)
- def load_and_merge_datasets(num_datasets: int = NUM_DATASETS) -> List[str]:
      texts = []
-     for i in range(1, num_datasets + 1):
-         name = DATASET_TEMPLATE.format(i)
-         try:
-             ds = load_dataset(name, split="train")
-             # expect a field 'text'
-             texts.extend(list(ds["text"]))
-         except Exception as e:
-             print(f"Warning: failed to load {name}: {e}")
-     return texts

- # -------------------------
- # Load model & processor
- # -------------------------
- @lru_cache(maxsize=1)
- def load_model_and_processor(model_id: str = MODEL_ID, use_token: bool = USE_HF_TOKEN):
-     kwargs = {}
-     if use_token and HF_TOKEN:
-         kwargs["use_auth_token"] = HF_TOKEN
-     processor = AutoProcessor.from_pretrained(model_id, **kwargs)
-     model = AutoModel.from_pretrained(model_id, **kwargs)
-     model.to(device)
      model.eval()
-     return processor, model
-
- # -------------------------
- # Precompute text embeddings (CPU) and return tensors + raw texts
- # -------------------------
- @lru_cache(maxsize=1)
- def precompute_text_embeddings(texts_tuple: Tuple[str, ...]):
-     # convert tuple back to list
-     texts = list(texts_tuple)
-     processor, model = load_model_and_processor()

      text_embeds_all = []
-     for i in tqdm(range(0, len(texts), BATCH_SIZE), desc="Encoding texts (startup)"):
-         batch_texts = texts[i : i + BATCH_SIZE]
-         # processor returns PyTorch tensors by default
-         inputs = processor(text=batch_texts, padding=True, truncation=True, return_tensors="pt")
-         # encode on device then move embeddings to CPU
-         inputs = {k: v.to(device) for k, v in inputs.items()}
          with torch.no_grad():
              text_embeds = model.get_text_features(**inputs)
              text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
          text_embeds_all.append(text_embeds.cpu())
          del inputs, text_embeds
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()

-     if len(text_embeds_all) == 0:
-         return torch.empty((0, 0)), []
      text_embeds_all = torch.cat(text_embeds_all, dim=0)
-     return text_embeds_all, texts

  # -------------------------
- # High-level initialization step (runs on import)
- # -------------------------
- print("Starting app: loading data and model — this may take a minute...")
- raw_texts = load_and_merge_datasets()
- print(f"Loaded {len(raw_texts)} text captions from datasets (merged).")
- text_embeds_all, texts_all = precompute_text_embeddings(tuple(raw_texts))
- print(f"Precomputed text embeddings: {text_embeds_all.shape}")
- processor, model = load_model_and_processor()
-
  # -------------------------
- # Retrieval function
- # -------------------------
-
- def retrieve_top_k_texts_from_image(image: Image.Image, k: int = TOP_K_DEFAULT) -> List[Tuple[str, float]]:
-     # prepare image
-     inputs = processor(images=image, return_tensors="pt")
-     inputs = {k: v.to(device) for k, v in inputs.items()}
-
      with torch.no_grad():
          img_embed = model.get_image_features(**inputs)
          img_embed = img_embed / img_embed.norm(p=2, dim=-1, keepdim=True)

-     # move to CPU and compute similarity with precomputed text embeddings
      sims = F.cosine_similarity(img_embed.cpu(), text_embeds_all)
      topk = torch.topk(sims, k)
-     results = []
-     for i in range(k):
-         idx = topk.indices[i].item()
-         score = topk.values[i].item()
-         results.append((texts_all[idx], float(score)))
      return results

  # -------------------------
  # Gradio interface
  # -------------------------

- def gradio_predict(img, k=TOP_K_DEFAULT):
-     if img is None:
-         return None, "No image provided."
-     if isinstance(img, str):
-         image = Image.open(img).convert("RGB")
-     else:
-         image = img.convert("RGB")
-
-     results = retrieve_top_k_texts_from_image(image, k=int(k))
-     # format for display
-     formatted = "\n\n".join([f"Rank {i+1}: {t}\n(score={s:.4f})" for i, (t, s) in enumerate(results)])
-     return image, formatted
-
- with gr.Blocks(title="SigLip Image -> Text Retriever") as demo:
-     gr.Markdown("# SigLip Image → Text retrieval demo\nUpload an image and get the top-k matching texts from the dataset.")
      with gr.Row():
          img_in = gr.Image(type="pil")
          out_img = gr.Image(type="pil", label="Image")
-     txt_out = gr.Textbox(label="Top-k matches", lines=8)
-     k_slider = gr.Slider(minimum=1, maximum=10, step=1, value=TOP_K_DEFAULT, label="Top k")
-     run_btn = gr.Button("Retrieve")

-     run_btn.click(fn=gradio_predict, inputs=[img_in, k_slider], outputs=[out_img, txt_out])

- # Expose the app
  if __name__ == "__main__":
      demo.launch(server_name="0.0.0.0", share=False)
-
-
- # -------------------------
- # requirements.txt (place in your Space as requirements.txt):
- # -------------------------
- # torch
- # torchvision
- # transformers==4.44.2
- # datasets
- # gradio
- # huggingface_hub
- # accelerate
- # pillow
- # tqdm
-
- # -------------------------
- # Quick setup checklist for HF Space
- # -------------------------
- # 1. Create a new Space (Gradio). In Settings -> Hardware, choose GPU if available and you expect faster inference.
- # 2. Add the requirements.txt (as above).
- # 3. If the model/datasets are private, go to Settings -> Secrets and add HUGGINGFACE_TOKEN with a token that has access.
- # 4. If you set the secret, the app will automatically pick it up from the HUGGINGFACE_TOKEN env var.
- # 5. Commit this app.py and requirements.txt to the Space and your app should start.
-
- # -------------------------
- # Tips & Troubleshooting
- # -------------------------
- # - Startup time may be long (model download, dataset download, text embedding encoding). Consider saving precomputed text embeddings
- #   to a file (np.save / torch.save) and loading them to speed startup. In Spaces, persistent storage is /workspace or /root/.cache.
- # - If memory is tight, reduce NUM_DATASETS or BATCH_SIZE, or compute embeddings offline and upload a precomputed tensor.
- # - Avoid printing too many things in Spaces logs to reduce noise.
- # -------------------------
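The removed version's Tips section suggests persisting the precomputed text embeddings (torch.save / np.save) so the Space does not re-encode every caption on each startup. A minimal sketch of that idea, assuming a writable cache path; the EMBED_CACHE_PATH constant and the encode_all_texts callable are illustrative stand-ins for the batched get_text_features loop above, not part of either version of the app:

import os
import torch

EMBED_CACHE_PATH = "text_embeds.pt"  # hypothetical cache file in the Space's persistent storage

def get_or_build_text_embeddings(texts, encode_all_texts):
    # Reuse the cached tensor when it matches the current caption count; otherwise rebuild it.
    if os.path.exists(EMBED_CACHE_PATH):
        cached = torch.load(EMBED_CACHE_PATH, map_location="cpu")
        if cached.shape[0] == len(texts):
            return cached
    embeds = encode_all_texts(texts)        # e.g. the batched text-encoding loop from the app
    torch.save(embeds.cpu(), EMBED_CACHE_PATH)  # persist for the next startup
    return embeds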
 
  """
+ Gradio Space app (app.py) — SigLip Image + Question → Llava Response
+
+ Pipeline:
+ 1. User uploads an agriculture image.
+ 2. User asks a question about the image.
+ 3. SigLip model retrieves top-k text captions relevant to the image.
+ 4. The retrieved text, original image, and user's question are sent to a Llava model.
+ 5. Llava generates a context-aware response with crop suggestions or explanations.
+
+ This updated app handles both the image retrieval and multi-modal question answering.
  """

  import os
  from functools import lru_cache
  from typing import List, Tuple

  import gradio as gr
  import torch
  import torch.nn.functional as F
+ from datasets import load_dataset
  from PIL import Image
+ from transformers import AutoProcessor, AutoModel
  from tqdm import tqdm

  # -------------------------
  # Config
  # -------------------------
+ SIGLIP_MODEL_ID = "EYEDOL/siglipFULL-agri-finetuned"
+ LLAVA_MODEL_ID = "llava-hf/llava-1.5-7b-hf"  # replace with actual model
+ DATASET_TEMPLATE = "EYEDOL/AGRILLAVA-image-text{}"
+ NUM_DATASETS = 1
+ BATCH_SIZE = 16
  TOP_K_DEFAULT = 3

  # Device
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  # -------------------------
+ # SigLip: load & precompute text embeddings
  # -------------------------
  @lru_cache(maxsize=1)
+ def load_singlip_texts_and_embeddings():
      texts = []
+     for i in range(1, NUM_DATASETS + 1):
+         ds = load_dataset(DATASET_TEMPLATE.format(i), split="train")
+         texts.extend(ds["text"])

+     processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_ID)
+     model = AutoModel.from_pretrained(SIGLIP_MODEL_ID).to(device)
      model.eval()

      text_embeds_all = []
+     for i in tqdm(range(0, len(texts), BATCH_SIZE), desc="Encoding texts"):
+         batch_texts = texts[i:i+BATCH_SIZE]
+         inputs = processor(text=batch_texts, padding=True, truncation=True, return_tensors="pt").to(device)
          with torch.no_grad():
              text_embeds = model.get_text_features(**inputs)
              text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
          text_embeds_all.append(text_embeds.cpu())
          del inputs, text_embeds
+         torch.cuda.empty_cache()

      text_embeds_all = torch.cat(text_embeds_all, dim=0)
+     return processor, model, texts, text_embeds_all

  # -------------------------
+ # SigLip retrieval
  # -------------------------
+ def retrieve_top_k_texts(image: Image.Image, k=TOP_K_DEFAULT):
+     processor, model, texts_all, text_embeds_all = load_singlip_texts_and_embeddings()
+     inputs = processor(images=image, return_tensors="pt").to(device)
      with torch.no_grad():
          img_embed = model.get_image_features(**inputs)
          img_embed = img_embed / img_embed.norm(p=2, dim=-1, keepdim=True)

      sims = F.cosine_similarity(img_embed.cpu(), text_embeds_all)
      topk = torch.topk(sims, k)
+     results = [(texts_all[idx.item()], float(score)) for idx, score in zip(topk.indices, topk.values)]
      return results

+ # -------------------------
+ # Llava response
+ # -------------------------
+ @lru_cache(maxsize=1)
+ def load_llava_model():
+     from transformers import AutoModelForCausalLM, AutoTokenizer
+     tokenizer = AutoTokenizer.from_pretrained(LLAVA_MODEL_ID)
+     model = AutoModelForCausalLM.from_pretrained(LLAVA_MODEL_ID).to(device)
+     model.eval()
+     return tokenizer, model
+
+ def llava_answer(image: Image.Image, retrieved_texts: List[Tuple[str, float]], question: str, max_tokens=256):
+     tokenizer, model = load_llava_model()
+     context_text = "\n".join([f"Retrieved Text: {t}" for t, _ in retrieved_texts])
+     prompt = f"Given the image and the following texts:\n{context_text}\nUser Question: {question}\nProvide a detailed answer and crop suggestions."
+
+     inputs = tokenizer(prompt, return_tensors="pt").to(device)
+     with torch.no_grad():
+         output_ids = model.generate(**inputs, max_new_tokens=max_tokens)
+     response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+     return response
+
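# Editor's sketch (not part of this commit): llava_answer above only feeds the text prompt to the
# model; the uploaded image never reaches it, even though the module docstring says the image is
# sent to Llava. If the checkpoint really is "llava-hf/llava-1.5-7b-hf", one way to pass the image
# is through LlavaForConditionalGeneration and its processor. The helper below is a hypothetical
# variant for illustration only; it reloads the model on every call for brevity, whereas in
# practice it should be cached the same way load_llava_model is.

def llava_answer_with_image(image: Image.Image, retrieved_texts, question: str, max_tokens=256):
    from transformers import AutoProcessor, LlavaForConditionalGeneration
    llava_processor = AutoProcessor.from_pretrained(LLAVA_MODEL_ID)
    # On a GPU Space, consider torch_dtype=torch.float16 to fit the 7B weights in memory.
    llava_model = LlavaForConditionalGeneration.from_pretrained(LLAVA_MODEL_ID).to(device)
    llava_model.eval()

    context_text = "\n".join(f"Retrieved Text: {t}" for t, _ in retrieved_texts)
    # llava-1.5 checkpoints expect an <image> placeholder inside a USER/ASSISTANT-style prompt.
    prompt = f"USER: <image>\n{context_text}\nQuestion: {question} ASSISTANT:"
    inputs = llava_processor(images=image, text=prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output_ids = llava_model.generate(**inputs, max_new_tokens=max_tokens)
    return llava_processor.decode(output_ids[0], skip_special_tokens=True)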
  # -------------------------
  # Gradio interface
  # -------------------------

+ def gradio_pipeline(image: Image.Image, question: str, k: int = TOP_K_DEFAULT):
+     if image is None or not question:
+         return None, "Please provide both image and question."
+
+     retrieved_texts = retrieve_top_k_texts(image, k=int(k))
+     response = llava_answer(image, retrieved_texts, question)
+     return image, response
+
+ with gr.Blocks(title="Agri Image + Question → Llava Response") as demo:
+     gr.Markdown("# Agri Image Question Answering\nUpload an agriculture image, ask a question, and get context-aware crop suggestions.")
      with gr.Row():
          img_in = gr.Image(type="pil")
          out_img = gr.Image(type="pil", label="Image")
+     question_input = gr.Textbox(label="Question about the image", lines=2)
+     k_slider = gr.Slider(minimum=1, maximum=10, step=1, value=TOP_K_DEFAULT, label="Top-k retrieval")
+     txt_out = gr.Textbox(label="Llava Response", lines=8)
+     run_btn = gr.Button("Generate Answer")

+     run_btn.click(fn=gradio_pipeline, inputs=[img_in, question_input, k_slider], outputs=[out_img, txt_out])

  if __name__ == "__main__":
      demo.launch(server_name="0.0.0.0", share=False)
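One design note on the new version: load_singlip_texts_and_embeddings is lazy (the dataset download and text encoding run on the first call and are then cached by lru_cache), so the first user request pays that cost instead of Space startup. If the old startup-time behaviour is preferred, a one-line warm-up at import time would restore it; a sketch, not part of this commit:

# Optional: warm the lru_cache at import so the first request is fast
load_singlip_texts_and_embeddings()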