EYEDOL committed on
Commit a48863e · verified · 1 Parent(s): 94b81b9

Update app.py

Files changed (1)
  1. app.py +39 -43
app.py CHANGED
@@ -1,4 +1,4 @@
- # app.py — Robust CPU-friendly SigLip -> (Llava local OR HF-inference fallback) pipeline
  import os
  # Force CPU before importing torch/transformers if you want CPU-only
  os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
@@ -6,37 +6,43 @@ os.environ.setdefault("TRANSFORMERS_NO_ADVISORY_WARNINGS", "1")
 
  import sys
  import traceback
- from typing import List, Tuple, Optional
  import json
- import requests
- import time
 
  import torch
  import torch.nn.functional as F
  from datasets import load_dataset
- from transformers import AutoProcessor, AutoModel, AutoTokenizer, AutoModelForCausalLM
  from PIL import Image
  import gradio as gr
  from tqdm import tqdm
 
  # -------------------------
- # Config (update model ids & dataset count)
  # -------------------------
  SIGLIP_MODEL_ID = "EYEDOL/siglipFULL-agri-finetuned"
  LLAVA_MODEL_ID = "llava-hf/llava-1.5-7b-hf"  # change if needed
  DATASET_TEMPLATE = "EYEDOL/AGRILLAVA-image-text{}"
- NUM_DATASETS = 1  # set to 15 if you want full data (startup time/memory increases)
  BATCH_SIZE = 16
  TOP_K_DEFAULT = 3
- HF_API_URL = f"https://api-inference.huggingface.co/models/{LLAVA_MODEL_ID}"
  HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", None)
 
- # Device
  device = torch.device("cpu")
- print("Device:", device)
 
  # -------------------------
- # Load SigLip dataset & model → precompute text embeddings at startup
  # -------------------------
  print("Loading datasets and computing SigLip text embeddings (startup)...")
  texts_all: List[str] = []
@@ -48,9 +54,9 @@ siglip_processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_ID)
  siglip_model = AutoModel.from_pretrained(SIGLIP_MODEL_ID).to(device)
  siglip_model.eval()
 
- # compute embeddings
  text_embeds_parts = []
- for i in tqdm(range(0, len(texts_all), BATCH_SIZE), desc="Encoding texts"):
  batch_texts = texts_all[i : i + BATCH_SIZE]
  inputs = siglip_processor(text=batch_texts, padding=True, truncation=True, return_tensors="pt")
  with torch.no_grad():
@@ -69,7 +75,7 @@ print(f"Encoded {len(texts_all)} texts. Embeddings shape: {text_embeds_all.shape
  # -------------------------
  llava_tokenizer: Optional[AutoTokenizer] = None
  llava_model = None
- llava_mode = None  # 'local', 'trust_remote_code', or 'hf_api' or None
  load_errors = []
 
  # Attempt 1: local llava package (preferred)
@@ -89,10 +95,10 @@ try:
  llava_model.eval()
  llava_mode = "local"
  print("✅ Llava loaded from installed package.")
- except Exception as e_local:
  tb_local = traceback.format_exc()
  load_errors.append(("local_llava_import", tb_local))
- print("Local llava import failed — will try trust_remote_code fallback. (see logs)")
 
  # Attempt 2: trust_remote_code fallback
  if llava_mode is None:
@@ -110,33 +116,30 @@ if llava_mode is None:
  llava_model.eval()
  llava_mode = "trust_remote_code"
  print("✅ Llava loaded via trust_remote_code fallback.")
- except Exception as e_trust:
  tb_trust = traceback.format_exc()
  load_errors.append(("fallback_trust_remote_code", tb_trust))
- print("trust_remote_code fallback failed.")
 
- # Attempt 3: Hugging Face Inference API fallback (requires HUGGINGFACE_TOKEN)
  if llava_mode is None and HUGGINGFACE_TOKEN:
- # we won't load a model locally; will call inference API for generation
  llava_mode = "hf_api"
- print("No local model available. Will use Hugging Face Inference API for generation (HUGGINGFACE_TOKEN detected).")
 
- # If still no method available, keep llava_mode None and continue — UI will show actionable message
  if llava_mode is None:
- print("WARNING: No Llava model available locally or via trust_remote_code, and no HUGGINGFACE_TOKEN found.")
- print("App will start but generation will return an actionable error. See load_errors for tracebacks.")
  for name, tb in load_errors:
- print(f"--- {name} traceback ---")
- print(tb)
 
  # -------------------------
- # Helper: call Hugging Face Inference API for text generation
  # -------------------------
  def call_hf_inference_api(prompt: str, max_new_tokens: int = 256, temperature: float = 0.0):
  if not HUGGINGFACE_TOKEN:
  raise RuntimeError("HUGGINGFACE_TOKEN not set; cannot call Hugging Face Inference API.")
- headers = {"Authorization": f"Bearer {HUGGINGFACE_TOKEN}"}
  payload = {
  "inputs": prompt,
  "parameters": {"max_new_tokens": max_new_tokens, "temperature": temperature},
  "options": {"wait_for_model": True},
@@ -145,19 +148,17 @@ def call_hf_inference_api(prompt: str, max_new_tokens: int = 256, temperature: f
  if resp.status_code != 200:
  raise RuntimeError(f"HF Inference API error {resp.status_code}: {resp.text}")
  data = resp.json()
- # API returns list or dict depending on model; handle common shapes
  if isinstance(data, list) and data and isinstance(data[0], dict) and "generated_text" in data[0]:
  return data[0]["generated_text"]
  if isinstance(data, dict) and "generated_text" in data:
  return data["generated_text"]
- # If the model returns a plain string or other structure:
  if isinstance(data, str):
  return data
- # Fallback: try to stringify
  return json.dumps(data)
 
  # -------------------------
- # Retrieval & generation functions
  # -------------------------
  def retrieve_top_k_texts(image: Image.Image, k: int = TOP_K_DEFAULT):
  inputs = siglip_processor(images=image, return_tensors="pt")
@@ -180,7 +181,6 @@ def llava_answer(image: Image.Image, retrieved_texts, question: str, max_tokens:
  )
 
  if llava_mode in ("local", "trust_remote_code"):
- # use the tokenizer + local model
  inputs = llava_tokenizer(prompt, return_tensors="pt")
  inputs = {k: v.to(device) for k, v in inputs.items()}
  with torch.no_grad():
@@ -188,28 +188,24 @@
  resp = llava_tokenizer.decode(output_ids[0], skip_special_tokens=True)
  return resp
  elif llava_mode == "hf_api":
- # Use HF Inference API
  return call_hf_inference_api(prompt, max_new_tokens=max_tokens)
  else:
- # No model available — return actionable error for the UI
  err = (
  "No Llava model is available for generation.\n\n"
- "Options to fix:\n"
  "1) Install the LLaVA repo in requirements.txt and rebuild the Space:\n"
  " git+https://github.com/haotian-liu/LLaVA.git@main\n"
- "2) Or provide a Hugging Face API token as the HUGGINGFACE_TOKEN secret in Space settings so the app can\n"
- f" fall back to the Inference API. Expected token env var name: HUGGINGFACE_TOKEN\n\n"
- "Debug info (tracebacks were printed to Space logs at startup).\n"
  )
  return err
 
  # -------------------------
- # Gradio pipeline + UI
  # -------------------------
  def gradio_pipeline(image: Image.Image, question: str, k: int = TOP_K_DEFAULT):
  if image is None or not question:
  return None, "Please provide both an image and a question."
-
  retrieved = retrieve_top_k_texts(image, k=int(k))
  try:
  answer = llava_answer(image, retrieved, question)
@@ -221,8 +217,8 @@ def gradio_pipeline(image: Image.Image, question: str, k: int = TOP_K_DEFAULT):
  with gr.Blocks(title="Agri Image + Question → Llava Response (robust)") as demo:
  gr.Markdown(
  "## Agri Image QA\n\nThis app preloads SigLip embeddings at startup. "
- "Generation uses a local Llava model if available, otherwise the Hugging Face Inference API "
- "(requires HUGGINGFACE_TOKEN set in Space secrets)."
  )
  with gr.Row():
  img_in = gr.Image(type="pil")
 
+ # app.py — Robust CPU-friendly SigLip -> (Llava local | trust_remote_code | HF router) pipeline
  import os
  # Force CPU before importing torch/transformers if you want CPU-only
  os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
 
  import sys
  import traceback
  import json
+ from typing import List, Optional
 
+ import requests
  import torch
  import torch.nn.functional as F
  from datasets import load_dataset
+ from transformers import (
+     AutoProcessor,
+     AutoModel,
+     AutoTokenizer,
+     AutoModelForCausalLM,
+ )
  from PIL import Image
  import gradio as gr
  from tqdm import tqdm
 
  # -------------------------
+ # Config - update these IDs as needed
  # -------------------------
  SIGLIP_MODEL_ID = "EYEDOL/siglipFULL-agri-finetuned"
  LLAVA_MODEL_ID = "llava-hf/llava-1.5-7b-hf"  # change if needed
  DATASET_TEMPLATE = "EYEDOL/AGRILLAVA-image-text{}"
+ NUM_DATASETS = 1  # set to 15 if you want all datasets (startup memory/time increases)
  BATCH_SIZE = 16
  TOP_K_DEFAULT = 3
+
+ # Hugging Face router endpoint (new inference endpoint)
+ HF_API_URL = "https://router.huggingface.co/hf-inference"
  HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", None)
 
+ # Device - CPU only
  device = torch.device("cpu")
+ print("Running on device:", device)
 
  # -------------------------
+ # Load dataset and SigLip model & precompute text embeddings at startup
  # -------------------------
  print("Loading datasets and computing SigLip text embeddings (startup)...")
  texts_all: List[str] = []
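A note for readers of this hunk: the commit replaces the legacy per-model endpoint (`https://api-inference.huggingface.co/models/{LLAVA_MODEL_ID}`) with a bare router root, and moves the model id into the request payload (see the `"model": LLAVA_MODEL_ID` line in the helper further down). The `requests.post(...)` call itself falls outside the displayed hunks, so the following is only a sketch of the request shape this config implies; the per-model `/models/<id>` path on the router is an assumption, not something the diff confirms:

import os
import requests

HF_API_URL = "https://router.huggingface.co/hf-inference"  # value set by this commit
LLAVA_MODEL_ID = "llava-hf/llava-1.5-7b-hf"
token = os.environ["HUGGINGFACE_TOKEN"]

# Assumed route: <router root>/models/<model id>; adjust if the app posts elsewhere.
resp = requests.post(
    f"{HF_API_URL}/models/{LLAVA_MODEL_ID}",
    headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"},
    json={"inputs": "Describe this leaf disease.", "parameters": {"max_new_tokens": 64}},
    timeout=120,
)
resp.raise_for_status()
print(resp.json())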
 
  siglip_model = AutoModel.from_pretrained(SIGLIP_MODEL_ID).to(device)
  siglip_model.eval()
 
+ # Precompute text embeddings (on CPU)
  text_embeds_parts = []
+ for i in tqdm(range(0, len(texts_all), BATCH_SIZE), desc="Encoding texts (CPU)"):
  batch_texts = texts_all[i : i + BATCH_SIZE]
  inputs = siglip_processor(text=batch_texts, padding=True, truncation=True, return_tensors="pt")
  with torch.no_grad():
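The hunk cuts off inside the encoding loop. Given the `torch.nn.functional as F` import and the `Embeddings shape` print visible in the surrounding context, the elided loop body plausibly normalizes and collects each batch as sketched below; `get_text_features` is the standard SigLIP text call in transformers, but the exact lines are an assumption:

# Hypothetical completion of the loop body (not shown in this diff):
with torch.no_grad():
    feats = siglip_model.get_text_features(**inputs)  # (batch, dim)
feats = F.normalize(feats, dim=-1)  # unit-normalize so retrieval is a cosine similarity
text_embeds_parts.append(feats)

# After the loop, stack all batches into one (num_texts, dim) matrix:
text_embeds_all = torch.cat(text_embeds_parts, dim=0)
print(f"Encoded {len(texts_all)} texts. Embeddings shape: {text_embeds_all.shape}")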
 
  # -------------------------
  llava_tokenizer: Optional[AutoTokenizer] = None
  llava_model = None
+ llava_mode: Optional[str] = None  # 'local', 'trust_remote_code', 'hf_api', or None
  load_errors = []
 
  # Attempt 1: local llava package (preferred)
 
  llava_model.eval()
  llava_mode = "local"
  print("✅ Llava loaded from installed package.")
+ except Exception:
  tb_local = traceback.format_exc()
  load_errors.append(("local_llava_import", tb_local))
+ print("Local llava import failed — will try trust_remote_code fallback. See logs for details.")
 
  # Attempt 2: trust_remote_code fallback
  if llava_mode is None:
 
  llava_model.eval()
  llava_mode = "trust_remote_code"
  print("✅ Llava loaded via trust_remote_code fallback.")
+ except Exception:
  tb_trust = traceback.format_exc()
  load_errors.append(("fallback_trust_remote_code", tb_trust))
+ print("trust_remote_code fallback failed — will try HF router if token provided.")
 
+ # Attempt 3: Hugging Face router Inference API fallback (requires HUGGINGFACE_TOKEN)
  if llava_mode is None and HUGGINGFACE_TOKEN:
 
  llava_mode = "hf_api"
+ print("No usable local model found. Will use Hugging Face router Inference API for generation (HUGGINGFACE_TOKEN detected).")
 
  if llava_mode is None:
+ print("WARNING: No Llava model available and no HUGGINGFACE_TOKEN supplied. Generation will return an actionable error.")
  for name, tb in load_errors:
+ print(f"--- {name} traceback ---\n{tb}")
 
  # -------------------------
+ # Helper: call Hugging Face router inference API
  # -------------------------
  def call_hf_inference_api(prompt: str, max_new_tokens: int = 256, temperature: float = 0.0):
  if not HUGGINGFACE_TOKEN:
  raise RuntimeError("HUGGINGFACE_TOKEN not set; cannot call Hugging Face Inference API.")
+ headers = {"Authorization": f"Bearer {HUGGINGFACE_TOKEN}", "Content-Type": "application/json"}
  payload = {
+ "model": LLAVA_MODEL_ID,
  "inputs": prompt,
  "parameters": {"max_new_tokens": max_new_tokens, "temperature": temperature},
  "options": {"wait_for_model": True},
 
  if resp.status_code != 200:
  raise RuntimeError(f"HF Inference API error {resp.status_code}: {resp.text}")
  data = resp.json()
+ # handle common response shapes
  if isinstance(data, list) and data and isinstance(data[0], dict) and "generated_text" in data[0]:
  return data[0]["generated_text"]
  if isinstance(data, dict) and "generated_text" in data:
  return data["generated_text"]
  if isinstance(data, str):
  return data
  return json.dumps(data)
 
  # -------------------------
+ # Retrieval & generation
  # -------------------------
  def retrieve_top_k_texts(image: Image.Image, k: int = TOP_K_DEFAULT):
  inputs = siglip_processor(images=image, return_tensors="pt")
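`retrieve_top_k_texts` is likewise truncated after the processor call. With unit-normalized embeddings, retrieval reduces to one matrix product plus `topk`; a sketch of the likely remainder (assumed, not shown in the diff):

# Hypothetical remainder of retrieve_top_k_texts (not shown in this diff):
with torch.no_grad():
    img_feats = siglip_model.get_image_features(**inputs)  # (1, dim)
img_feats = F.normalize(img_feats, dim=-1)

sims = (img_feats @ text_embeds_all.T).squeeze(0)  # cosine similarity to every text
scores, idx = sims.topk(int(k))
return [texts_all[i] for i in idx.tolist()]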
 
  )
 
  if llava_mode in ("local", "trust_remote_code"):
  inputs = llava_tokenizer(prompt, return_tensors="pt")
  inputs = {k: v.to(device) for k, v in inputs.items()}
  with torch.no_grad():
  resp = llava_tokenizer.decode(output_ids[0], skip_special_tokens=True)
  return resp
  elif llava_mode == "hf_api":
  return call_hf_inference_api(prompt, max_new_tokens=max_tokens)
  else:
  err = (
  "No Llava model is available for generation.\n\n"
+ "Fix options:\n"
  "1) Install the LLaVA repo in requirements.txt and rebuild the Space:\n"
  " git+https://github.com/haotian-liu/LLaVA.git@main\n"
+ "2) Or add a valid Hugging Face API token as HUGGINGFACE_TOKEN in Space secrets to use the router.\n\n"
+ "Check Space logs for detailed tracebacks printed at startup."
  )
  return err
 
  # -------------------------
+ # Gradio app
  # -------------------------
  def gradio_pipeline(image: Image.Image, question: str, k: int = TOP_K_DEFAULT):
  if image is None or not question:
  return None, "Please provide both an image and a question."
  retrieved = retrieve_top_k_texts(image, k=int(k))
  try:
  answer = llava_answer(image, retrieved, question)
 
  with gr.Blocks(title="Agri Image + Question → Llava Response (robust)") as demo:
  gr.Markdown(
  "## Agri Image QA\n\nThis app preloads SigLip embeddings at startup. "
+ "Generation uses a local Llava model if available, otherwise the Hugging Face router Inference API "
+ "(requires HUGGINGFACE_TOKEN secret in Space settings)."
  )
  with gr.Row():
  img_in = gr.Image(type="pil")
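The diff ends inside the `gr.Row()` layout, so the remaining inputs, the click wiring, and the launch call are not shown. For orientation, a typical Gradio pattern that would finish a `Blocks` app like this one; every widget name below (`q_in`, `k_in`, `btn`, the two outputs) is illustrative rather than taken from the commit:

# Illustrative completion of the UI (not part of this commit):
        q_in = gr.Textbox(label="Question")
        k_in = gr.Slider(1, 10, value=TOP_K_DEFAULT, step=1, label="Top-k retrieved texts")
    retrieved_out = gr.JSON(label="Retrieved texts")
    answer_out = gr.Textbox(label="Answer")
    btn = gr.Button("Ask")
    # gradio_pipeline returns (retrieved, answer), matching the two outputs here
    btn.click(gradio_pipeline, inputs=[img_in, q_in, k_in], outputs=[retrieved_out, answer_out])

demo.launch()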