EYEDOL committed
Commit 94b81b9 · verified · 1 Parent(s): cf1d39e

Update app.py

Files changed (1): app.py +116 -74

app.py CHANGED
@@ -1,12 +1,15 @@
- # app.py (CPU-friendly, preloaded SigLip + Llava with robust loading)
  import os
- # FORCE CPU: must be set before importing torch/transformers
- os.environ["CUDA_VISIBLE_DEVICES"] = ""
- os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"

  import sys
  import traceback
- from typing import List, Tuple

  import torch
  import torch.nn.functional as F
@@ -17,23 +20,25 @@ import gradio as gr
  from tqdm import tqdm

  # -------------------------
- # Config - update these IDs as needed
  # -------------------------
  SIGLIP_MODEL_ID = "EYEDOL/siglipFULL-agri-finetuned"
- LLAVA_MODEL_ID = "llava-hf/llava-1.5-7b-hf"  # <-- replace with your HF repo ID if different
  DATASET_TEMPLATE = "EYEDOL/AGRILLAVA-image-text{}"
- NUM_DATASETS = 1  # set to 15 if you want all datasets loaded (startup memory/time increases)
  BATCH_SIZE = 16
  TOP_K_DEFAULT = 3

- # Device - CPU only
  device = torch.device("cpu")
- print("Running on device:", device)

  # -------------------------
- # Load dataset and SigLip
  # -------------------------
- print("Loading datasets and computing SigLip text embeddings (CPU)...")
  texts_all: List[str] = []
  for i in range(1, NUM_DATASETS + 1):
      ds = load_dataset(DATASET_TEMPLATE.format(i), split="train")
@@ -43,41 +48,36 @@ siglip_processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_ID)
  siglip_model = AutoModel.from_pretrained(SIGLIP_MODEL_ID).to(device)
  siglip_model.eval()

- # Precompute text embeddings (CPU)
- text_embeds_list = []
- for i in tqdm(range(0, len(texts_all), BATCH_SIZE), desc="Encoding texts (CPU)"):
      batch_texts = texts_all[i : i + BATCH_SIZE]
      inputs = siglip_processor(text=batch_texts, padding=True, truncation=True, return_tensors="pt")
-     # inputs are on CPU
      with torch.no_grad():
          text_embeds = siglip_model.get_text_features(**inputs)
          text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
-     text_embeds_list.append(text_embeds.cpu())
      del inputs, text_embeds
- if len(text_embeds_list) == 0:
-     text_embeds_all = torch.empty((0, 0))
  else:
-     text_embeds_all = torch.cat(text_embeds_list, dim=0)
- print(f"Finished encoding {len(texts_all)} texts. Embeddings shape: {text_embeds_all.shape}")

  # -------------------------
- # Load Llava model & tokenizer (robust)
- # Strategy:
- #   1) Try to import LlavaForCausalLM from installed llava package (recommended).
- #   2) If not available, try AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True).
- #   3) If both fail, raise a clear error with instructions.
  # -------------------------
- llava_tokenizer = None
  llava_model = None
-
  load_errors = []

  # Attempt 1: local llava package (preferred)
  try:
-     # Import here so we don't require the package unless we need it
      from llava.model import LlavaForCausalLM  # type: ignore

-     print("Found installed 'llava' package — loading LlavaForCausalLM from it (CPU)...")
      llava_tokenizer = AutoTokenizer.from_pretrained(LLAVA_MODEL_ID, use_fast=False)
      llava_model = LlavaForCausalLM.from_pretrained(
          LLAVA_MODEL_ID,
@@ -87,12 +87,15 @@ try:
      )
      llava_model.to(device)
      llava_model.eval()
-     print("✅ LlavaForCausalLM loaded via local llava package.")
  except Exception as e_local:
      tb_local = traceback.format_exc()
      load_errors.append(("local_llava_import", tb_local))
-     print("Local llava import/load failed — will attempt fallback (trust_remote_code=True).")
-     # Attempt 2: trust_remote_code fallback
      try:
          print("Attempting AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True) (CPU)...")
          llava_tokenizer = AutoTokenizer.from_pretrained(LLAVA_MODEL_ID, use_fast=False)
@@ -105,28 +108,56 @@ except Exception as e_local:
          )
          llava_model.to(device)
          llava_model.eval()
-         print("✅ Llava model loaded via trust_remote_code fallback.")
-     except Exception as e_fallback:
-         tb_fallback = traceback.format_exc()
-         load_errors.append(("fallback_trust_remote_code", tb_fallback))
-         # Both failed — raise a helpful error describing how to fix
-         err_msg = (
-             "Failed to load the Llava model using both strategies.\n\n"
-             "Recommended fixes:\n"
-             "1) Add the LLaVA repo to requirements.txt so the `llava` package and LlavaForCausalLM are installed:\n"
-             "   git+https://github.com/haotian-liu/LLaVA.git@main\n"
-             "   Then rebuild your Space.\n\n"
-             "2) If you prefer trust_remote_code, ensure the HF model repo supports `trust_remote_code=True` and\n"
-             "   that any repo-specific dependencies (listed in the repo README) are installed in requirements.txt.\n\n"
-             "Debug details (tracebacks):\n\n"
-         )
-         for name, tb in load_errors:
-             err_msg += f"--- {name} traceback ---\n{tb}\n"
-         # raise RuntimeError with the composed message
-         raise RuntimeError(err_msg)

  # -------------------------
- # SigLip retrieval function
  # -------------------------
  def retrieve_top_k_texts(image: Image.Image, k: int = TOP_K_DEFAULT):
      inputs = siglip_processor(images=image, return_tensors="pt")
@@ -139,48 +170,59 @@ def retrieve_top_k_texts(image: Image.Image, k: int = TOP_K_DEFAULT):
      results = [(texts_all[idx.item()], float(score)) for idx, score in zip(topk.indices, topk.values)]
      return results

- # -------------------------
- # Llava answer function
- # -------------------------
  def llava_answer(image: Image.Image, retrieved_texts, question: str, max_tokens: int = 256):
-     # Compose context: retrieved text + short instruction
      context_text = "\n".join([f"Retrieved Text: {t}" for t, _ in retrieved_texts])
      prompt = (
-         "You are an agricultural assistant. Use the provided retrieved texts and the image context to answer the user's question.\n\n"
          f"Retrieved texts:\n{context_text}\n\n"
          f"User question: {question}\n\n"
-         "Provide a concise, actionable answer and crop suggestions where appropriate."
      )

-     inputs = llava_tokenizer(prompt, return_tensors="pt")
-     # ensure tokens are on CPU
-     inputs = {k: v.to(device) for k, v in inputs.items()}
-     with torch.no_grad():
-         output_ids = llava_model.generate(**inputs, max_new_tokens=max_tokens)
-     response = llava_tokenizer.decode(output_ids[0], skip_special_tokens=True)
-     return response

  # -------------------------
- # Gradio pipeline
  # -------------------------
  def gradio_pipeline(image: Image.Image, question: str, k: int = TOP_K_DEFAULT):
      if image is None or not question:
          return None, "Please provide both an image and a question."

      retrieved = retrieve_top_k_texts(image, k=int(k))
      try:
          answer = llava_answer(image, retrieved, question)
      except Exception as e:
          tb = traceback.format_exc()
-         answer = f"Error while generating answer: {e}\n\nTraceback:\n{tb}"
      return image, answer

- # -------------------------
- # Gradio app
- # -------------------------
- with gr.Blocks(title="Agri Image + Question → Llava Response (CPU)") as demo:
      gr.Markdown(
-         "# Agri Image QA (CPU)\n\nUpload an agriculture image and ask a question. "
-         "This Space preloads models and embeddings at startup for faster responses."
      )
      with gr.Row():
          img_in = gr.Image(type="pil")

+ # app.py — Robust CPU-friendly SigLip -> (Llava local OR HF-inference fallback) pipeline
  import os
+ # Force CPU before importing torch/transformers if you want CPU-only
+ os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
+ os.environ.setdefault("TRANSFORMERS_NO_ADVISORY_WARNINGS", "1")

  import sys
  import traceback
+ from typing import List, Tuple, Optional
+ import json
+ import requests
+ import time

  import torch
  import torch.nn.functional as F
  from tqdm import tqdm

  # -------------------------
+ # Config (update model IDs & dataset count)
  # -------------------------
  SIGLIP_MODEL_ID = "EYEDOL/siglipFULL-agri-finetuned"
+ LLAVA_MODEL_ID = "llava-hf/llava-1.5-7b-hf"  # change if needed
  DATASET_TEMPLATE = "EYEDOL/AGRILLAVA-image-text{}"
+ NUM_DATASETS = 1  # set to 15 if you want full data (startup time/memory increases)
  BATCH_SIZE = 16
  TOP_K_DEFAULT = 3
+ HF_API_URL = f"https://api-inference.huggingface.co/models/{LLAVA_MODEL_ID}"
+ HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", None)

+ # Device
  device = torch.device("cpu")
+ print("Device:", device)

  # -------------------------
+ # Load SigLip dataset & model → precompute text embeddings at startup
  # -------------------------
+ print("Loading datasets and computing SigLip text embeddings (startup)...")
  texts_all: List[str] = []
  for i in range(1, NUM_DATASETS + 1):
      ds = load_dataset(DATASET_TEMPLATE.format(i), split="train")
  siglip_model = AutoModel.from_pretrained(SIGLIP_MODEL_ID).to(device)
  siglip_model.eval()

+ # compute embeddings
+ text_embeds_parts = []
+ for i in tqdm(range(0, len(texts_all), BATCH_SIZE), desc="Encoding texts"):
      batch_texts = texts_all[i : i + BATCH_SIZE]
      inputs = siglip_processor(text=batch_texts, padding=True, truncation=True, return_tensors="pt")
      with torch.no_grad():
          text_embeds = siglip_model.get_text_features(**inputs)
          text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+     text_embeds_parts.append(text_embeds.cpu())
      del inputs, text_embeds
+ if text_embeds_parts:
+     text_embeds_all = torch.cat(text_embeds_parts, dim=0)
  else:
+     text_embeds_all = torch.empty((0, 0))
+ print(f"Encoded {len(texts_all)} texts. Embeddings shape: {text_embeds_all.shape}")
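Because every row of text_embeds_all is L2-normalized here (and the image embedding is normalized the same way at query time), cosine similarity reduces to a plain matrix product followed by topk. A minimal, self-contained illustration with toy tensors, not the app's real embeddings:

import torch
import torch.nn.functional as F

# Toy stand-ins: 5 unit-norm "text" vectors and one unit-norm "image" vector (dim 8).
text_vecs = F.normalize(torch.randn(5, 8), dim=-1)
img_vec = F.normalize(torch.randn(1, 8), dim=-1)

# With unit-norm rows, the dot product equals the cosine similarity.
sims = (img_vec @ text_vecs.T).squeeze(0)  # shape (5,)
top = sims.topk(k=3)
print(top.indices.tolist(), top.values.tolist())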
  # -------------------------
+ # Llava loading: try local package -> trust_remote_code -> HF Inference API (if token provided)
  # -------------------------
+ llava_tokenizer: Optional[AutoTokenizer] = None
  llava_model = None
+ llava_mode = None  # 'local', 'trust_remote_code', 'hf_api', or None
  load_errors = []

  # Attempt 1: local llava package (preferred)
  try:
+     # this import requires the LLaVA repo to be installed in the environment (requirements.txt)
      from llava.model import LlavaForCausalLM  # type: ignore

+     print("Loading LlavaForCausalLM from installed 'llava' package (CPU)...")
      llava_tokenizer = AutoTokenizer.from_pretrained(LLAVA_MODEL_ID, use_fast=False)
      llava_model = LlavaForCausalLM.from_pretrained(
          LLAVA_MODEL_ID,

      )
      llava_model.to(device)
      llava_model.eval()
+     llava_mode = "local"
+     print("✅ Llava loaded from installed package.")
  except Exception as e_local:
      tb_local = traceback.format_exc()
      load_errors.append(("local_llava_import", tb_local))
+     print("Local llava import failed — will try trust_remote_code fallback (see logs).")
+
+ # Attempt 2: trust_remote_code fallback
+ if llava_mode is None:
      try:
          print("Attempting AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True) (CPU)...")
          llava_tokenizer = AutoTokenizer.from_pretrained(LLAVA_MODEL_ID, use_fast=False)

          )
          llava_model.to(device)
          llava_model.eval()
+         llava_mode = "trust_remote_code"
+         print("✅ Llava loaded via trust_remote_code fallback.")
+     except Exception as e_trust:
+         tb_trust = traceback.format_exc()
+         load_errors.append(("fallback_trust_remote_code", tb_trust))
+         print("trust_remote_code fallback failed.")
+
+ # Attempt 3: Hugging Face Inference API fallback (requires HUGGINGFACE_TOKEN)
+ if llava_mode is None and HUGGINGFACE_TOKEN:
+     # we won't load a model locally; will call inference API for generation
+     llava_mode = "hf_api"
+     print("No local model available. Will use Hugging Face Inference API for generation (HUGGINGFACE_TOKEN detected).")
+
+ # If still no method available, keep llava_mode None and continue — UI will show actionable message
+ if llava_mode is None:
+     print("WARNING: No Llava model available locally or via trust_remote_code, and no HUGGINGFACE_TOKEN found.")
+     print("App will start but generation will return an actionable error. See load_errors for tracebacks.")
+     for name, tb in load_errors:
+         print(f"--- {name} traceback ---")
+         print(tb)
+
+ # -------------------------
+ # Helper: call Hugging Face Inference API for text generation
+ # -------------------------
+ def call_hf_inference_api(prompt: str, max_new_tokens: int = 256, temperature: float = 0.0):
+     if not HUGGINGFACE_TOKEN:
+         raise RuntimeError("HUGGINGFACE_TOKEN not set; cannot call Hugging Face Inference API.")
+     headers = {"Authorization": f"Bearer {HUGGINGFACE_TOKEN}"}
+     payload = {
+         "inputs": prompt,
+         "parameters": {"max_new_tokens": max_new_tokens, "temperature": temperature},
+         "options": {"wait_for_model": True},
+     }
+     resp = requests.post(HF_API_URL, headers=headers, json=payload, timeout=300)
+     if resp.status_code != 200:
+         raise RuntimeError(f"HF Inference API error {resp.status_code}: {resp.text}")
+     data = resp.json()
+     # API returns list or dict depending on model; handle common shapes
+     if isinstance(data, list) and data and isinstance(data[0], dict) and "generated_text" in data[0]:
+         return data[0]["generated_text"]
+     if isinstance(data, dict) and "generated_text" in data:
+         return data["generated_text"]
+     # If the model returns a plain string or other structure:
+     if isinstance(data, str):
+         return data
+     # Fallback: try to stringify
+     return json.dumps(data)

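A quick way to smoke-test this helper outside the app; the prompt string is only an illustration, and the call runs only when HUGGINGFACE_TOKEN is set:

import os

if os.environ.get("HUGGINGFACE_TOKEN"):
    try:
        out = call_hf_inference_api("List two drought-tolerant maize practices.", max_new_tokens=64)
        print(out)
    except RuntimeError as err:
        # Non-200 responses (rate limits, model still loading) surface here with the API's message.
        print("Inference API call failed:", err)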
  # -------------------------
+ # Retrieval & generation functions
  # -------------------------
  def retrieve_top_k_texts(image: Image.Image, k: int = TOP_K_DEFAULT):
      inputs = siglip_processor(images=image, return_tensors="pt")

      results = [(texts_all[idx.item()], float(score)) for idx, score in zip(topk.indices, topk.values)]
      return results

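The hunk elides the unchanged middle of retrieve_top_k_texts. Given the L2-normalized text_embeds_all and the topk usage visible above, the hidden lines presumably resemble this sketch (a reconstruction, not the committed code; sims is a hypothetical name):

# Hypothetical reconstruction of the elided function body:
with torch.no_grad():
    image_embeds = siglip_model.get_image_features(**inputs)
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
sims = (image_embeds @ text_embeds_all.T).squeeze(0)  # cosine similarity per text
topk = sims.topk(k=min(int(k), sims.shape[0]))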
  def llava_answer(image: Image.Image, retrieved_texts, question: str, max_tokens: int = 256):
      context_text = "\n".join([f"Retrieved Text: {t}" for t, _ in retrieved_texts])
      prompt = (
+         "You are an agricultural assistant. Use the provided retrieved texts to answer concisely.\n\n"
          f"Retrieved texts:\n{context_text}\n\n"
          f"User question: {question}\n\n"
+         "Provide a concise, actionable answer and crop suggestions when applicable."
      )

+     if llava_mode in ("local", "trust_remote_code"):
+         # use the tokenizer + local model
+         inputs = llava_tokenizer(prompt, return_tensors="pt")
+         inputs = {k: v.to(device) for k, v in inputs.items()}
+         with torch.no_grad():
+             output_ids = llava_model.generate(**inputs, max_new_tokens=max_tokens)
+         resp = llava_tokenizer.decode(output_ids[0], skip_special_tokens=True)
+         return resp
+     elif llava_mode == "hf_api":
+         # Use HF Inference API
+         return call_hf_inference_api(prompt, max_new_tokens=max_tokens)
+     else:
+         # No model available — return actionable error for the UI
+         err = (
+             "No Llava model is available for generation.\n\n"
+             "Options to fix:\n"
+             "1) Install the LLaVA repo in requirements.txt and rebuild the Space:\n"
+             "   git+https://github.com/haotian-liu/LLaVA.git@main\n"
+             "2) Or provide a Hugging Face API token as the HUGGINGFACE_TOKEN secret in Space settings so the app can\n"
+             "   fall back to the Inference API. Expected token env var name: HUGGINGFACE_TOKEN\n\n"
+             "Debug info (tracebacks were printed to Space logs at startup).\n"
+         )
+         return err

 
206
  # -------------------------
207
+ # Gradio pipeline + UI
208
  # -------------------------
209
  def gradio_pipeline(image: Image.Image, question: str, k: int = TOP_K_DEFAULT):
210
  if image is None or not question:
211
  return None, "Please provide both an image and a question."
212
+
213
  retrieved = retrieve_top_k_texts(image, k=int(k))
214
  try:
215
  answer = llava_answer(image, retrieved, question)
216
  except Exception as e:
217
  tb = traceback.format_exc()
218
+ answer = f"Error during generation: {e}\n\nTraceback:\n{tb}"
219
  return image, answer
220
 
221
+ with gr.Blocks(title="Agri Image + Question → Llava Response (robust)") as demo:
 
 
 
222
  gr.Markdown(
223
+ "## Agri Image QA\n\nThis app preloads SigLip embeddings at startup. "
224
+ "Generation uses a local Llava model if available, otherwise the Hugging Face Inference API "
225
+ "(requires HUGGINGFACE_TOKEN set in Space secrets)."
226
  )
227
  with gr.Row():
228
  img_in = gr.Image(type="pil")
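The excerpt is cut off here; the rest of the Blocks body is not shown. For orientation only, a minimal sketch of how the visible pieces (gradio_pipeline, img_in, TOP_K_DEFAULT) are typically wired up in Gradio; every widget name below is hypothetical:

# Hypothetical wiring, not from the commit: question box, top-k slider, button, outputs.
with gr.Blocks() as demo_sketch:
    with gr.Row():
        img_in = gr.Image(type="pil")
        question_in = gr.Textbox(label="Question")
    k_slider = gr.Slider(1, 10, value=TOP_K_DEFAULT, step=1, label="Top-k retrieved texts")
    ask_btn = gr.Button("Ask")
    img_out = gr.Image(type="pil")
    answer_out = gr.Textbox(label="Answer")
    ask_btn.click(gradio_pipeline, inputs=[img_in, question_in, k_slider], outputs=[img_out, answer_out])

demo_sketch.launch()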