EYEDOL committed
Commit 463c8c8 · verified · 1 Parent(s): e3c600e

Update app.py

Files changed (1)
  1. app.py +54 -56
app.py CHANGED
@@ -1,77 +1,83 @@
- """
- Gradio Space app: Preloaded SigLip + Llava pipeline for instant user response.
- Pipeline:
- 1. Startup: load SigLip processor, model, compute all text embeddings.
- 2. Startup: load Llava tokenizer & LlavaForCausalLM model.
- 3. User uploads image + asks question → instant retrieval + Llava response.
- """
-
  import os
- from typing import List, Tuple

- import gradio as gr
  import torch
  import torch.nn.functional as F
  from datasets import load_dataset
  from PIL import Image
- from transformers import AutoProcessor
-
- # Install llava repo if not already installed:
- # pip install git+https://github.com/haotian-liu/LLaVA.git
- from llava.model import LlavaForCausalLM
- from transformers import AutoTokenizer

  SIGLIP_MODEL_ID = "EYEDOL/siglipFULL-agri-finetuned"
- LLAVA_MODEL_ID = "llava-hf/llava-1.5-7b-hf"  # replace with your actual model repo
  DATASET_TEMPLATE = "EYEDOL/AGRILLAVA-image-text{}"
  NUM_DATASETS = 1
  BATCH_SIZE = 16
  TOP_K_DEFAULT = 3

- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  # -------------------------
- # Startup: load datasets and compute SigLip text embeddings
  # -------------------------
- print("Loading datasets and computing SigLip text embeddings...")
  texts_all = []
  for i in range(1, NUM_DATASETS + 1):
      ds = load_dataset(DATASET_TEMPLATE.format(i), split="train")
      texts_all.extend(ds["text"])

  siglip_processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_ID)
  siglip_model = AutoModel.from_pretrained(SIGLIP_MODEL_ID).to(device)
  siglip_model.eval()

  text_embeds_all = []
- for i in range(0, len(texts_all), BATCH_SIZE):
-     batch_texts = texts_all[i:i+BATCH_SIZE]
-     inputs = siglip_processor(text=batch_texts, padding=True, truncation=True, return_tensors="pt").to(device)
      with torch.no_grad():
          text_embeds = siglip_model.get_text_features(**inputs)
      text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
      text_embeds_all.append(text_embeds.cpu())
      del inputs, text_embeds
-     torch.cuda.empty_cache()
-
  text_embeds_all = torch.cat(text_embeds_all, dim=0)
- print(f"Finished encoding {len(texts_all)} texts. Shape: {text_embeds_all.shape}")

  # -------------------------
- # Startup: load Llava model & tokenizer
  # -------------------------
- print("Loading Llava tokenizer and LlavaForCausalLM model...")
  llava_tokenizer = AutoTokenizer.from_pretrained(LLAVA_MODEL_ID, use_fast=False)
- llava_model = LlavaForCausalLM.from_pretrained(LLAVA_MODEL_ID).to(device)
  llava_model.eval()
- print("Llava model loaded.")

  # -------------------------
- # SigLip retrieval function
  # -------------------------
-
  def retrieve_top_k_texts(image: Image.Image, k=TOP_K_DEFAULT):
-     inputs = siglip_processor(images=image, return_tensors="pt").to(device)
      with torch.no_grad():
          img_embed = siglip_model.get_image_features(**inputs)
      img_embed = img_embed / img_embed.norm(p=2, dim=-1, keepdim=True)
@@ -81,37 +87,30 @@ def retrieve_top_k_texts(image: Image.Image, k=TOP_K_DEFAULT):
      results = [(texts_all[idx.item()], float(score)) for idx, score in zip(topk.indices, topk.values)]
      return results

- # -------------------------
- # Llava answer function
- # -------------------------
-
- def llava_answer(image: Image.Image, retrieved_texts: List[str], question: str, max_tokens=256):
      context_text = "\n".join([f"Retrieved Text: {t}" for t, _ in retrieved_texts])
      prompt = f"Given the image and the following texts:\n{context_text}\nUser Question: {question}\nProvide a detailed answer and crop suggestions."

-     inputs = llava_tokenizer(prompt, return_tensors="pt").to(device)
      with torch.no_grad():
-         output_ids = llava_model.generate(**inputs, max_new_tokens=max_tokens)
-     response = llava_tokenizer.decode(output_ids[0], skip_special_tokens=True)
-     return response

  # -------------------------
- # Gradio interface pipeline
  # -------------------------
-
  def gradio_pipeline(image: Image.Image, question: str, k: int = TOP_K_DEFAULT):
      if image is None or not question:
-         return None, "Please provide both image and question."

-     retrieved_texts = retrieve_top_k_texts(image, k=int(k))
-     response = llava_answer(image, retrieved_texts, question)
-     return image, response
-
- # -------------------------
- # Gradio Blocks
- # -------------------------
- with gr.Blocks(title="Agri Image + Question → Llava Response") as demo:
-     gr.Markdown("# Agri Image Question Answering\nUpload an agriculture image, ask a question, and get context-aware crop suggestions.")
      with gr.Row():
          img_in = gr.Image(type="pil")
          out_img = gr.Image(type="pil", label="Image")
@@ -119,8 +118,7 @@ with gr.Blocks(title="Agri Image + Question → Llava Response") as demo:
      k_slider = gr.Slider(minimum=1, maximum=10, step=1, value=TOP_K_DEFAULT, label="Top-k retrieval")
      txt_out = gr.Textbox(label="Llava Response", lines=8)
      run_btn = gr.Button("Generate Answer")
-
      run_btn.click(fn=gradio_pipeline, inputs=[img_in, question_input, k_slider], outputs=[out_img, txt_out])

  if __name__ == "__main__":
-     demo.launch(server_name="0.0.0.0", share=False)
 
+ # app.py (CPU-only version)

  import os
+ # FORCE CPU: disable CUDA visibility for this process before importing torch/transformers
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""  # important: must be set before torch import
+ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"

  import torch
  import torch.nn.functional as F
  from datasets import load_dataset
+ from transformers import AutoProcessor, AutoModel, AutoTokenizer, AutoModelForCausalLM
  from PIL import Image
+ import gradio as gr
+ from tqdm import tqdm

+ # -------------------------
+ # Config - set your model IDs here
+ # -------------------------
  SIGLIP_MODEL_ID = "EYEDOL/siglipFULL-agri-finetuned"
+ LLAVA_MODEL_ID = "llava-hf/llava-1.5-7b-hf"  # <-- replace this with the HF repo ID
  DATASET_TEMPLATE = "EYEDOL/AGRILLAVA-image-text{}"
  NUM_DATASETS = 1
  BATCH_SIZE = 16
  TOP_K_DEFAULT = 3

+ # Device - CPU only
+ device = torch.device("cpu")
+ print("Running on device:", device)

  # -------------------------
+ # Load dataset and SigLip (as before)
  # -------------------------
+ print("Loading datasets and computing SigLip text embeddings (CPU)...")
  texts_all = []
  for i in range(1, NUM_DATASETS + 1):
      ds = load_dataset(DATASET_TEMPLATE.format(i), split="train")
      texts_all.extend(ds["text"])

  siglip_processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_ID)
+ # Use AutoModel for Siglip (same as before)
  siglip_model = AutoModel.from_pretrained(SIGLIP_MODEL_ID).to(device)
  siglip_model.eval()

+ # Precompute text embeddings (on CPU) -- this may take time
  text_embeds_all = []
+ for i in tqdm(range(0, len(texts_all), BATCH_SIZE), desc="Encoding texts (CPU)"):
+     batch_texts = texts_all[i : i + BATCH_SIZE]
+     inputs = siglip_processor(text=batch_texts, padding=True, truncation=True, return_tensors="pt")
+     # ensure tensors are on CPU (they already are)
      with torch.no_grad():
          text_embeds = siglip_model.get_text_features(**inputs)
      text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
      text_embeds_all.append(text_embeds.cpu())
      del inputs, text_embeds
  text_embeds_all = torch.cat(text_embeds_all, dim=0)
+ print(f"Finished encoding {len(texts_all)} texts. Embeddings shape: {text_embeds_all.shape}")

  # -------------------------
+ # Load Llava tokenizer + model on CPU
  # -------------------------
+ print("Loading Llava tokenizer and model (CPU, trust_remote_code=True)...")
+ # Use the slow tokenizer if the fast one fails on Spaces
  llava_tokenizer = AutoTokenizer.from_pretrained(LLAVA_MODEL_ID, use_fast=False)
+
+ # Use trust_remote_code=True so the repo's custom model class is used.
+ # Use device_map={"": "cpu"} to force all model weights to CPU; use torch_dtype=float32 for safety.
+ llava_model = AutoModelForCausalLM.from_pretrained(
+     LLAVA_MODEL_ID,
+     trust_remote_code=True,
+     device_map={"": "cpu"},
+     torch_dtype=torch.float32,
+     low_cpu_mem_usage=True,  # helps reduce RAM usage when possible
+ )
  llava_model.eval()
+ print("Llava model loaded onto CPU.")

  # -------------------------
+ # Retrieval and answer functions
  # -------------------------
 
  def retrieve_top_k_texts(image: Image.Image, k=TOP_K_DEFAULT):
+     inputs = siglip_processor(images=image, return_tensors="pt")
      with torch.no_grad():
          img_embed = siglip_model.get_image_features(**inputs)
      img_embed = img_embed / img_embed.norm(p=2, dim=-1, keepdim=True)

      results = [(texts_all[idx.item()], float(score)) for idx, score in zip(topk.indices, topk.values)]
      return results

+ def llava_answer(image: Image.Image, retrieved_texts, question: str, max_tokens=256):
      context_text = "\n".join([f"Retrieved Text: {t}" for t, _ in retrieved_texts])
      prompt = f"Given the image and the following texts:\n{context_text}\nUser Question: {question}\nProvide a detailed answer and crop suggestions."

+     inputs = llava_tokenizer(prompt, return_tensors="pt")
+     # ensure inputs are on CPU
+     inputs = {k: v.to("cpu") for k, v in inputs.items()}
      with torch.no_grad():
+         out = llava_model.generate(**inputs, max_new_tokens=max_tokens)
+     resp = llava_tokenizer.decode(out[0], skip_special_tokens=True)
+     return resp

  # -------------------------
+ # Gradio pipeline
  # -------------------------
  def gradio_pipeline(image: Image.Image, question: str, k: int = TOP_K_DEFAULT):
      if image is None or not question:
+         return None, "Please provide both an image and a question."
+     retrieved = retrieve_top_k_texts(image, k=int(k))
+     answer = llava_answer(image, retrieved, question)
+     return image, answer

+ with gr.Blocks(title="Agri Image + Question → Llava Response (CPU)") as demo:
+     gr.Markdown("# Agri Image QA (CPU)\nUpload an agriculture image + question. This runs fully on CPU.")
      with gr.Row():
          img_in = gr.Image(type="pil")
          out_img = gr.Image(type="pil", label="Image")

      k_slider = gr.Slider(minimum=1, maximum=10, step=1, value=TOP_K_DEFAULT, label="Top-k retrieval")
      txt_out = gr.Textbox(label="Llava Response", lines=8)
      run_btn = gr.Button("Generate Answer")
      run_btn.click(fn=gradio_pipeline, inputs=[img_in, question_input, k_slider], outputs=[out_img, txt_out])

  if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", share=False)
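
The diff view collapses the unchanged middle of retrieve_top_k_texts (old lines 78-80 / new lines 84-86), which is where topk must be computed. A minimal sketch of that step, assuming it works the way the surrounding code implies: both embedding sets are already L2-normalized, so cosine similarity reduces to a matrix product. The hidden lines may differ in detail.

import torch

def top_k_by_cosine(img_embed, text_embeds, k=3):
    # img_embed: [1, D], text_embeds: [N, D]; both L2-normalized,
    # so the dot product equals cosine similarity
    sims = (text_embeds @ img_embed.T).squeeze(1)  # [N] similarity scores
    topk = torch.topk(sims, k=min(k, sims.numel()))
    return topk  # topk.indices / topk.values, as consumed by `results = [...]`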
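Because CUDA_VISIBLE_DEVICES is cleared before torch is imported, the whole pipeline can be sanity-checked without a GPU. A hypothetical smoke test, not part of the commit (sample_leaf.jpg is a made-up filename), that could run inside app.py after startup, before demo.launch():

import torch
from PIL import Image

assert not torch.cuda.is_available()  # the env var hides every GPU from torch

sample = Image.open("sample_leaf.jpg")  # hypothetical local test image
for text, score in retrieve_top_k_texts(sample, k=3):
    # retrieve_top_k_texts returns (text, cosine score) pairs
    print(f"{score:.3f}  {text[:80]}")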
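One thing worth flagging: in this version llava_answer tokenizes only the text prompt, so the uploaded image never reaches the language model; generation is conditioned on the retrieved texts alone. If LLAVA_MODEL_ID really is the stock llava-hf/llava-1.5-7b-hf checkpoint (the inline comment says to replace it), a sketch of a genuinely multimodal call through transformers' LLaVA classes could look like this:

from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
model.eval()

image = Image.open("sample_leaf.jpg")  # hypothetical test image
# llava-1.5 checkpoints expect the <image> placeholder in the prompt
prompt = "USER: <image>\nWhat crop issues do you see? ASSISTANT:"
inputs = processor(text=prompt, images=image, return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=256)
print(processor.decode(out[0], skip_special_tokens=True))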