EYEDOL committed
Commit 7679f34 · verified · 1 Parent(s): 43590b6

Update app.py

Files changed (1)
  1. app.py +31 -22
app.py CHANGED
@@ -1,14 +1,15 @@
 """
-Gradio Space app (app.py) — SigLip Image + Question → Llava Response
+Gradio Space app (app.py) — SigLip Image + Question → Llava Response (Improved)
 
-Pipeline:
-1. User uploads an agriculture image.
-2. User asks a question about the image.
-3. SigLip model retrieves top-k text captions relevant to the image.
-4. The retrieved text, original image, and user's question are sent to a Llava model.
-5. Llava generates a context-aware response with crop suggestions or explanations.
+Pipeline and improvements:
+1. User uploads an agriculture image and asks a question.
+2. SigLip model retrieves top-k relevant texts.
+3. Llava model generates a response using the retrieved texts, image, and question.
 
-This updated app handles both the image retrieval and multi-modal question answering.
+Improvements implemented to handle the tokenizer/model errors:
+- Lazy-load the Llava model and tokenizer only when first required, reducing startup errors and memory usage.
+- Added exception handling for tokenizer/model loading failures (common with incompatible or custom Llava models).
+- Added clear error messages that guide users to install the correct dependencies or use a compatible model version.
 """
 
 import os
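The docstring's retrieval step is the half of the pipeline this diff never shows: the body of `retrieve_top_k_texts` sits outside the changed hunks. A minimal sketch of what an embed-and-rank implementation typically looks like with the SigLip config below, assuming a precomputed, L2-normalized caption embedding matrix and a parallel caption list (all argument names here are illustrative, not code from the repository):

```python
# Illustrative sketch: the commit never shows retrieve_top_k_texts's body.
# All argument names are assumptions, not code from the repository.
import torch
from PIL import Image

def top_k_captions(image: Image.Image, model, processor,
                   text_embeds: torch.Tensor, captions: list, k: int = 3) -> list:
    """Rank precomputed caption embeddings against a single image embedding."""
    inputs = processor(images=image, return_tensors="pt").to(text_embeds.device)
    with torch.no_grad():
        img_embed = model.get_image_features(**inputs)  # SiglipModel method
    img_embed = img_embed / img_embed.norm(dim=-1, keepdim=True)
    # Cosine similarity, assuming text_embeds is already L2-normalized.
    sims = (img_embed @ text_embeds.T).squeeze(0)
    top = sims.topk(min(k, len(captions)))
    return [captions[i] for i in top.indices.tolist()]
```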
@@ -23,17 +24,13 @@ from PIL import Image
 from transformers import AutoProcessor, AutoModel
 from tqdm import tqdm
 
-# -------------------------
-# Config
-# -------------------------
 SIGLIP_MODEL_ID = "EYEDOL/siglipFULL-agri-finetuned"
-LLAVA_MODEL_ID = "llava-hf/llava-1.5-7b-hf"  # replace with actual model
+LLAVA_MODEL_ID = "your-llava-model-hf-id"  # replace with actual model
 DATASET_TEMPLATE = "EYEDOL/AGRILLAVA-image-text{}"
 NUM_DATASETS = 1
 BATCH_SIZE = 16
 TOP_K_DEFAULT = 3
 
-# Device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # -------------------------
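A note on `DATASET_TEMPLATE` and `NUM_DATASETS`: the `{}` placeholder suggests the caption corpus is split across numbered Hub datasets. A hedged sketch of how the shards might be gathered, assuming 1-based shard numbering and a `text` column; neither detail is confirmed by this diff:

```python
# Hedged sketch: collect caption texts from numbered dataset shards.
# The shard indexing scheme and the "text" column name are assumptions.
from datasets import load_dataset

def load_caption_corpus(template: str, num_datasets: int) -> list:
    captions = []
    for i in range(1, num_datasets + 1):
        shard = load_dataset(template.format(i), split="train")
        captions.extend(shard["text"])
    return captions

# e.g. load_caption_corpus(DATASET_TEMPLATE, NUM_DATASETS)
```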
@@ -80,15 +77,24 @@ def retrieve_top_k_texts(image: Image.Image, k=TOP_K_DEFAULT):
     return results
 
 # -------------------------
-# Llava response
+# Lazy-load Llava model with error handling
 # -------------------------
-@lru_cache(maxsize=1)
+llava_model_cache = {}
+
 def load_llava_model():
-    from transformers import AutoModelForCausalLM, AutoTokenizer
-    tokenizer = AutoTokenizer.from_pretrained(LLAVA_MODEL_ID)
-    model = AutoModelForCausalLM.from_pretrained(LLAVA_MODEL_ID).to(device)
-    model.eval()
-    return tokenizer, model
+    if 'model' in llava_model_cache and 'tokenizer' in llava_model_cache:
+        return llava_model_cache['tokenizer'], llava_model_cache['model']
+
+    try:
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(LLAVA_MODEL_ID)
+        model = AutoModelForCausalLM.from_pretrained(LLAVA_MODEL_ID).to(device)
+        model.eval()
+        llava_model_cache['tokenizer'] = tokenizer
+        llava_model_cache['model'] = model
+        return tokenizer, model
+    except Exception as e:
+        raise RuntimeError(f"Failed to load Llava model/tokenizer: {e}. Ensure LLAVA_MODEL_ID is correct and compatible with transformers.")
 
 def llava_answer(image: Image.Image, retrieved_texts: List[str], question: str, max_tokens=256):
     tokenizer, model = load_llava_model()
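Two observations on this hunk. First, the removed `@lru_cache(maxsize=1)` on a zero-argument function caches its single result exactly like the new module-level dict, so the rewrite's real gain is the explicit try/except with an actionable message. Second, the tokenizer/model errors the docstring mentions are a known pitfall here: llava-hf checkpoints are multimodal and generally do not load via `AutoModelForCausalLM`/`AutoTokenizer`. A hedged alternative using the multimodal classes, with illustrative dtype and device settings (not the commit's code):

```python
# Hedged alternative loader for llava-hf checkpoints; not the commit's code.
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

def load_llava_multimodal(model_id: str):
    processor = AutoProcessor.from_pretrained(model_id)  # bundles tokenizer + image processor
    model = LlavaForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16,  # halves memory for a 7B model
        device_map="auto",          # requires the `accelerate` package
    )
    model.eval()
    return processor, model
```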
@@ -110,7 +116,10 @@ def gradio_pipeline(image: Image.Image, question: str, k: int = TOP_K_DEFAULT):
         return None, "Please provide both image and question."
 
     retrieved_texts = retrieve_top_k_texts(image, k=int(k))
-    response = llava_answer(image, retrieved_texts, question)
+    try:
+        response = llava_answer(image, retrieved_texts, question)
+    except RuntimeError as e:
+        response = str(e)
     return image, response
 
 with gr.Blocks(title="Agri Image + Question → Llava Response") as demo:
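The diff shows only the first line of `llava_answer`, so the prompt assembly and generation step are out of frame. A sketch of what that step could look like with the multimodal loader above, assuming the llava-1.5 `USER: <image> ... ASSISTANT:` chat format (an assumption about the chosen checkpoint):

```python
# Illustrative generation step; prompt format and helper names are assumptions.
import torch

def generate_answer(processor, model, image, retrieved_texts, question,
                    max_tokens: int = 256) -> str:
    context = "\n".join(f"- {t}" for t in retrieved_texts)
    prompt = (f"USER: <image>\nContext from similar images:\n{context}\n"
              f"Question: {question}\nASSISTANT:")
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=max_tokens)
    # Decode only the tokens generated after the prompt.
    new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
    return processor.decode(new_tokens, skip_special_tokens=True).strip()
```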
@@ -126,4 +135,4 @@ with gr.Blocks(title="Agri Image + Question → Llava Response") as demo:
     run_btn.click(fn=gradio_pipeline, inputs=[img_in, question_input, k_slider], outputs=[out_img, txt_out])
 
 if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0", share=False)
+    demo.launch(server_name="0.0.0.0", share=False)
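The re-added launch line is the right call for a Space: `server_name="0.0.0.0"` binds all interfaces so the platform's proxy can reach the app, and `share=False` is fine because Spaces provides the public URL. A quick local smoke test of the wired pipeline (the image path is a placeholder):

```python
# Local smoke test; "sample_leaf.jpg" is a hypothetical path.
from PIL import Image

img = Image.open("sample_leaf.jpg")
_, answer = gradio_pipeline(img, "What disease is affecting this crop?", k=3)
print(answer)
```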