EYEDOL committed on
Commit 1278acb · verified · 1 Parent(s): 606744f

Update app.py

Files changed (1)
  1. app.py +113 -43
app.py CHANGED
@@ -1,25 +1,28 @@
- # app.py (CPU-only version)
+ # app.py (CPU-friendly, preloaded SigLip + Llava with robust loading)
  import os
- # FORCE CPU: disable CUDA visibility for this process before importing torch/transformers
- os.environ["CUDA_VISIBLE_DEVICES"] = "" # important: must be set before torch import
+ # FORCE CPU: must be set before importing torch/transformers
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
  os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"

+ import sys
+ import traceback
+ from typing import List, Tuple
+
  import torch
  import torch.nn.functional as F
  from datasets import load_dataset
- from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
  from transformers import AutoProcessor, AutoModel, AutoTokenizer, AutoModelForCausalLM
  from PIL import Image
  import gradio as gr
  from tqdm import tqdm

  # -------------------------
- # Config - set your model IDs here
+ # Config - update these IDs as needed
  # -------------------------
  SIGLIP_MODEL_ID = "EYEDOL/siglipFULL-agri-finetuned"
- LLAVA_MODEL_ID = "llava-hf/llava-1.5-7b-hf" # <-- replace this with the HF repo ID
+ LLAVA_MODEL_ID = "llava-hf/llava-1.5-7b-hf" # <-- replace with your HF repo ID if different
  DATASET_TEMPLATE = "EYEDOL/AGRILLAVA-image-text{}"
- NUM_DATASETS = 1
+ NUM_DATASETS = 1 # set to 15 if you want all datasets loaded (startup memory/time increases)
  BATCH_SIZE = 16
  TOP_K_DEFAULT = 3

@@ -28,56 +31,104 @@ device = torch.device("cpu")
  print("Running on device:", device)

  # -------------------------
- # Load dataset and SigLip (as before)
+ # Load dataset and SigLip
  # -------------------------
  print("Loading datasets and computing SigLip text embeddings (CPU)...")
- texts_all = []
+ texts_all: List[str] = []
  for i in range(1, NUM_DATASETS + 1):
      ds = load_dataset(DATASET_TEMPLATE.format(i), split="train")
      texts_all.extend(ds["text"])

  siglip_processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_ID)
- # Use AutoModel for Siglip (same as before)
  siglip_model = AutoModel.from_pretrained(SIGLIP_MODEL_ID).to(device)
  siglip_model.eval()

- # Precompute text embeddings (on CPU) -- this may take time
- text_embeds_all = []
+ # Precompute text embeddings (CPU)
+ text_embeds_list = []
  for i in tqdm(range(0, len(texts_all), BATCH_SIZE), desc="Encoding texts (CPU)"):
      batch_texts = texts_all[i : i + BATCH_SIZE]
      inputs = siglip_processor(text=batch_texts, padding=True, truncation=True, return_tensors="pt")
-     # ensure tensors are on CPU (they already are)
+     # inputs are on CPU
      with torch.no_grad():
          text_embeds = siglip_model.get_text_features(**inputs)
      text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
-     text_embeds_all.append(text_embeds.cpu())
+     text_embeds_list.append(text_embeds.cpu())
      del inputs, text_embeds
- text_embeds_all = torch.cat(text_embeds_all, dim=0)
+ if len(text_embeds_list) == 0:
+     text_embeds_all = torch.empty((0, 0))
+ else:
+     text_embeds_all = torch.cat(text_embeds_list, dim=0)
  print(f"Finished encoding {len(texts_all)} texts. Embeddings shape: {text_embeds_all.shape}")

  # -------------------------
- # Load Llava tokenizer + model on CPU
+ # Load Llava model & tokenizer (robust)
+ # Strategy:
+ # 1) Try to import LlavaForCausalLM from installed llava package (recommended).
+ # 2) If not available, try AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True).
+ # 3) If both fail, raise a clear error with instructions.
  # -------------------------
- print("Loading Llava tokenizer and model (CPU, trust_remote_code=True)...")
- # Use slow tokenizer if fast fails on Spaces
- llava_tokenizer = AutoTokenizer.from_pretrained(LLAVA_MODEL_ID, use_fast=False)
-
- # Use trust_remote_code=True so the repo's custom model class is used.
- # Use device_map={"": "cpu"} to force all model weights to CPU; use torch_dtype=float32 for safety.
- llava_model = AutoModelForCausalLM.from_pretrained(
-     LLAVA_MODEL_ID,
-     trust_remote_code=True,
-     device_map={"": "cpu"},
-     torch_dtype=torch.float32,
-     low_cpu_mem_usage=True # help reduce RAM usage when possible
- )
- llava_model.eval()
- print("Llava model loaded onto CPU.")
+ llava_tokenizer = None
+ llava_model = None
+
+ load_errors = []
+
+ # Attempt 1: local llava package (preferred)
+ try:
+     # Import here so we don't require the package unless we need it
+     from llava.model import LlavaForCausalLM  # type: ignore
+
+     print("Found installed 'llava' package — loading LlavaForCausalLM from it (CPU)...")
+     llava_tokenizer = AutoTokenizer.from_pretrained(LLAVA_MODEL_ID, use_fast=False)
+     llava_model = LlavaForCausalLM.from_pretrained(
+         LLAVA_MODEL_ID,
+         device_map={"": "cpu"},
+         torch_dtype=torch.float32,
+         low_cpu_mem_usage=True,
+     )
+     llava_model.to(device)
+     llava_model.eval()
+     print("✅ LlavaForCausalLM loaded via local llava package.")
+ except Exception as e_local:
+     tb_local = traceback.format_exc()
+     load_errors.append(("local_llava_import", tb_local))
+     print("Local llava import/load failed — will attempt fallback (trust_remote_code=True).")
+     # Attempt 2: trust_remote_code fallback
+     try:
+         print("Attempting AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True) (CPU)...")
+         llava_tokenizer = AutoTokenizer.from_pretrained(LLAVA_MODEL_ID, use_fast=False)
+         llava_model = AutoModelForCausalLM.from_pretrained(
+             LLAVA_MODEL_ID,
+             trust_remote_code=True,
+             device_map={"": "cpu"},
+             torch_dtype=torch.float32,
+             low_cpu_mem_usage=True,
+         )
+         llava_model.to(device)
+         llava_model.eval()
+         print("✅ Llava model loaded via trust_remote_code fallback.")
+     except Exception as e_fallback:
+         tb_fallback = traceback.format_exc()
+         load_errors.append(("fallback_trust_remote_code", tb_fallback))
+         # Both failed — raise a helpful error describing how to fix
+         err_msg = (
+             "Failed to load the Llava model using both strategies.\n\n"
+             "Recommended fixes:\n"
+             "1) Add the LLaVA repo to requirements.txt so the `llava` package and LlavaForCausalLM are installed:\n"
+             "   git+https://github.com/haotian-liu/LLaVA.git@main\n"
+             "   Then rebuild your Space.\n\n"
+             "2) If you prefer trust_remote_code, ensure the HF model repo supports `trust_remote_code=True` and\n"
+             "   that any repo-specific dependencies (listed in the repo README) are installed in requirements.txt.\n\n"
+             "Debug details (tracebacks):\n\n"
+         )
+         for name, tb in load_errors:
+             err_msg += f"--- {name} traceback ---\n{tb}\n"
+         # raise RuntimeError with the composed message
+         raise RuntimeError(err_msg)

  # -------------------------
- # Retrieval and answer functions
+ # SigLip retrieval function
  # -------------------------
- def retrieve_top_k_texts(image: Image.Image, k=TOP_K_DEFAULT):
+ def retrieve_top_k_texts(image: Image.Image, k: int = TOP_K_DEFAULT):
      inputs = siglip_processor(images=image, return_tensors="pt")
      with torch.no_grad():
          img_embed = siglip_model.get_image_features(**inputs)
@@ -88,17 +139,26 @@ def retrieve_top_k_texts(image: Image.Image, k=TOP_K_DEFAULT):
      results = [(texts_all[idx.item()], float(score)) for idx, score in zip(topk.indices, topk.values)]
      return results

- def llava_answer(image: Image.Image, retrieved_texts, question: str, max_tokens=256):
+ # -------------------------
+ # Llava answer function
+ # -------------------------
+ def llava_answer(image: Image.Image, retrieved_texts, question: str, max_tokens: int = 256):
+     # Compose context: retrieved text + short instruction
      context_text = "\n".join([f"Retrieved Text: {t}" for t, _ in retrieved_texts])
-     prompt = f"Given the image and the following texts:\n{context_text}\nUser Question: {question}\nProvide a detailed answer and crop suggestions."
+     prompt = (
+         "You are an agricultural assistant. Use the provided retrieved texts and the image context to answer the user's question.\n\n"
+         f"Retrieved texts:\n{context_text}\n\n"
+         f"User question: {question}\n\n"
+         "Provide a concise, actionable answer and crop suggestions where appropriate."
+     )

      inputs = llava_tokenizer(prompt, return_tensors="pt")
-     # ensure inputs are on CPU
-     inputs = {k: v.to("cpu") for k, v in inputs.items()}
+     # ensure tokens are on CPU
+     inputs = {k: v.to(device) for k, v in inputs.items()}
      with torch.no_grad():
-         out = llava_model.generate(**inputs, max_new_tokens=max_tokens)
-     resp = llava_tokenizer.decode(out[0], skip_special_tokens=True)
-     return resp
+         output_ids = llava_model.generate(**inputs, max_new_tokens=max_tokens)
+     response = llava_tokenizer.decode(output_ids[0], skip_special_tokens=True)
+     return response

  # -------------------------
  # Gradio pipeline
@@ -107,17 +167,27 @@ def gradio_pipeline(image: Image.Image, question: str, k: int = TOP_K_DEFAULT):
      if image is None or not question:
          return None, "Please provide both an image and a question."
      retrieved = retrieve_top_k_texts(image, k=int(k))
-     answer = llava_answer(image, retrieved, question)
+     try:
+         answer = llava_answer(image, retrieved, question)
+     except Exception as e:
+         tb = traceback.format_exc()
+         answer = f"Error while generating answer: {e}\n\nTraceback:\n{tb}"
      return image, answer

+ # -------------------------
+ # Gradio app
+ # -------------------------
  with gr.Blocks(title="Agri Image + Question → Llava Response (CPU)") as demo:
-     gr.Markdown("# Agri Image QA (CPU)\\nUpload an agriculture image + question. This runs fully on CPU.")
+     gr.Markdown(
+         "# Agri Image QA (CPU)\n\nUpload an agriculture image and ask a question. "
+         "This Space preloads models and embeddings at startup for faster responses."
+     )
      with gr.Row():
          img_in = gr.Image(type="pil")
          out_img = gr.Image(type="pil", label="Image")
      question_input = gr.Textbox(label="Question about the image", lines=2)
      k_slider = gr.Slider(minimum=1, maximum=10, step=1, value=TOP_K_DEFAULT, label="Top-k retrieval")
-     txt_out = gr.Textbox(label="Llava Response", lines=8)
+     txt_out = gr.Textbox(label="Llava Response", lines=12)
      run_btn = gr.Button("Generate Answer")
      run_btn.click(fn=gradio_pipeline, inputs=[img_in, question_input, k_slider], outputs=[out_img, txt_out])

 
193