EYEDOL committed
Commit 0905af1 · verified · 1 Parent(s): bcda1e2

Update app.py

Files changed (1)
  1. app.py +49 -67
app.py CHANGED
@@ -1,27 +1,21 @@
 """
-Gradio Space app (app.py) — SigLip Image + Question → Llava Response (Improved)
+Gradio Space app (app.py) — Preloaded SigLip + Llava pipeline for instant response
 
-Pipeline and improvements:
-1. User uploads an agriculture image and asks a question.
-2. SigLip model retrieves top-k relevant texts.
-3. Llava model generates a response using retrieved texts, image, and question.
-
-Improvements implemented to handle the Tokenizer/Model errors:
-- Lazy-load Llava model and tokenizer only when first required, reducing startup errors and memory usage.
-- Added exception handling for tokenizer/model loading failures (common with incompatible or custom Llava models).
-- Added clear error messages to guide installing correct dependencies or using compatible model versions.
+Pipeline:
+1. At startup: load SigLip processor & model, compute all text embeddings.
+2. At startup: load Llava tokenizer & model.
+3. User uploads an image and asks a question → pipeline uses preloaded resources for instant retrieval and response.
 """
 
 import os
-from functools import lru_cache
 from typing import List, Tuple
 
 import gradio as gr
 import torch
 import torch.nn.functional as F
-from datasets import load_dataset
+from datasets import load_dataset, concatenate_datasets
 from PIL import Image
-from transformers import AutoProcessor, AutoModel
+from transformers import AutoProcessor, AutoModel, AutoTokenizer, AutoModelForCausalLM
 from tqdm import tqdm
 
 SIGLIP_MODEL_ID = "EYEDOL/siglipFULL-agri-finetuned"
@@ -34,41 +28,49 @@ TOP_K_DEFAULT = 3
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # -------------------------
-# SigLip: load & precompute text embeddings
+# Startup: load all datasets and compute text embeddings
 # -------------------------
-@lru_cache(maxsize=1)
-def load_singlip_texts_and_embeddings():
-    texts = []
-    for i in range(1, NUM_DATASETS + 1):
-        ds = load_dataset(DATASET_TEMPLATE.format(i), split="train")
-        texts.extend(ds["text"])
-
-    processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_ID)
-    model = AutoModel.from_pretrained(SIGLIP_MODEL_ID).to(device)
-    model.eval()
-
-    text_embeds_all = []
-    for i in tqdm(range(0, len(texts), BATCH_SIZE), desc="Encoding texts"):
-        batch_texts = texts[i:i+BATCH_SIZE]
-        inputs = processor(text=batch_texts, padding=True, truncation=True, return_tensors="pt").to(device)
-        with torch.no_grad():
-            text_embeds = model.get_text_features(**inputs)
-        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
-        text_embeds_all.append(text_embeds.cpu())
-        del inputs, text_embeds
-        torch.cuda.empty_cache()
-
-    text_embeds_all = torch.cat(text_embeds_all, dim=0)
-    return processor, model, texts, text_embeds_all
+print("⏳ Loading datasets and computing SigLip text embeddings...")
+texts_all = []
+for i in range(1, NUM_DATASETS + 1):
+    ds = load_dataset(DATASET_TEMPLATE.format(i), split="train")
+    texts_all.extend(ds["text"])
+
+siglip_processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_ID)
+siglip_model = AutoModel.from_pretrained(SIGLIP_MODEL_ID).to(device)
+siglip_model.eval()
+
+text_embeds_all = []
+for i in tqdm(range(0, len(texts_all), BATCH_SIZE), desc="Encoding texts"):
+    batch_texts = texts_all[i:i+BATCH_SIZE]
+    inputs = siglip_processor(text=batch_texts, padding=True, truncation=True, return_tensors="pt").to(device)
+    with torch.no_grad():
+        text_embeds = siglip_model.get_text_features(**inputs)
+    text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+    text_embeds_all.append(text_embeds.cpu())
+    del inputs, text_embeds
+    torch.cuda.empty_cache()
+
+text_embeds_all = torch.cat(text_embeds_all, dim=0)
+print(f"✅ Finished encoding {len(texts_all)} texts. Shape: {text_embeds_all.shape}")
+
+# -------------------------
+# Startup: load Llava model & tokenizer
+# -------------------------
+print("⏳ Loading Llava model and tokenizer...")
+llava_tokenizer = AutoTokenizer.from_pretrained(LLAVA_MODEL_ID)
+llava_model = AutoModelForCausalLM.from_pretrained(LLAVA_MODEL_ID).to(device)
+llava_model.eval()
+print("✅ Llava model loaded.")
 
 # -------------------------
 # SigLip retrieval
 # -------------------------
+
 def retrieve_top_k_texts(image: Image.Image, k=TOP_K_DEFAULT):
-    processor, model, texts_all, text_embeds_all = load_singlip_texts_and_embeddings()
-    inputs = processor(images=image, return_tensors="pt").to(device)
+    inputs = siglip_processor(images=image, return_tensors="pt").to(device)
     with torch.no_grad():
-        img_embed = model.get_image_features(**inputs)
+        img_embed = siglip_model.get_image_features(**inputs)
     img_embed = img_embed / img_embed.norm(p=2, dim=-1, keepdim=True)
 
     sims = F.cosine_similarity(img_embed.cpu(), text_embeds_all)
@@ -77,34 +79,17 @@ def retrieve_top_k_texts(image: Image.Image, k=TOP_K_DEFAULT):
     return results
 
 # -------------------------
-# Lazy-load Llava model with error handling
+# Llava answer
 # -------------------------
-llava_model_cache = {}
-
-def load_llava_model():
-    if 'model' in llava_model_cache and 'tokenizer' in llava_model_cache:
-        return llava_model_cache['tokenizer'], llava_model_cache['model']
-
-    try:
-        from transformers import AutoModelForCausalLM, AutoTokenizer
-        tokenizer = AutoTokenizer.from_pretrained(LLAVA_MODEL_ID)
-        model = AutoModelForCausalLM.from_pretrained(LLAVA_MODEL_ID).to(device)
-        model.eval()
-        llava_model_cache['tokenizer'] = tokenizer
-        llava_model_cache['model'] = model
-        return tokenizer, model
-    except Exception as e:
-        raise RuntimeError(f"Failed to load Llava model/tokenizer: {e}. Ensure LLAVA_MODEL_ID is correct and compatible with transformers.")
 
 def llava_answer(image: Image.Image, retrieved_texts: List[str], question: str, max_tokens=256):
-    tokenizer, model = load_llava_model()
     context_text = "\n".join([f"Retrieved Text: {t}" for t, _ in retrieved_texts])
     prompt = f"Given the image and the following texts:\n{context_text}\nUser Question: {question}\nProvide a detailed answer and crop suggestions."
 
-    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+    inputs = llava_tokenizer(prompt, return_tensors="pt").to(device)
     with torch.no_grad():
-        output_ids = model.generate(**inputs, max_new_tokens=max_tokens)
-    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+        output_ids = llava_model.generate(**inputs, max_new_tokens=max_tokens)
+    response = llava_tokenizer.decode(output_ids[0], skip_special_tokens=True)
     return response
 
 # -------------------------
@@ -116,10 +101,7 @@ def gradio_pipeline(image: Image.Image, question: str, k: int = TOP_K_DEFAULT):
         return None, "Please provide both image and question."
 
     retrieved_texts = retrieve_top_k_texts(image, k=int(k))
-    try:
-        response = llava_answer(image, retrieved_texts, question)
-    except RuntimeError as e:
-        response = str(e)
+    response = llava_answer(image, retrieved_texts, question)
    return image, response
 
 with gr.Blocks(title="Agri Image + Question → Llava Response") as demo:
@@ -135,4 +117,4 @@ with gr.Blocks(title="Agri Image + Question → Llava Response") as demo:
     run_btn.click(fn=gradio_pipeline, inputs=[img_in, question_input, k_slider], outputs=[out_img, txt_out])
 
 if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0", share=False)
+    demo.launch(server_name="0.0.0.0", share=False)
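
Because this version front-loads the SigLip embeddings and the Llava model at import time, the pipeline can also be exercised directly in Python once app.py has loaded, not only through the Gradio UI. A minimal usage sketch follows; the image path and question are hypothetical examples, not part of this commit:

    from PIL import Image
    import app  # importing app.py runs the startup loading of SigLip, text embeddings, and Llava

    img = Image.open("leaf_sample.jpg").convert("RGB")  # hypothetical example image
    _, answer = app.gradio_pipeline(img, "What is affecting this crop and what should I plant next season?", k=3)
    print(answer)

Since demo.launch() sits under the __main__ guard, importing the module loads the models without starting the Gradio server.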