Update app.py
app.py CHANGED
@@ -1,14 +1,15 @@
 """
-Gradio Space app (app.py) → SigLip Image + Question → Llava Response
+Gradio Space app (app.py) → SigLip Image + Question → Llava Response (Improved)
 
-Pipeline:
-1. User uploads an agriculture image.
-2.
-3.
-4. The retrieved text, original image, and user's question are sent to a Llava model.
-5. Llava generates a context-aware response with crop suggestions or explanations.
+Pipeline and improvements:
+1. User uploads an agriculture image and asks a question.
+2. SigLip model retrieves top-k relevant texts.
+3. Llava model generates a response using retrieved texts, image, and question.
 
-
+Improvements implemented to handle the Tokenizer/Model errors:
+- Lazy-load Llava model and tokenizer only when first required, reducing startup errors and memory usage.
+- Added exception handling for tokenizer/model loading failures (common with incompatible or custom Llava models).
+- Added clear error messages to guide installing correct dependencies or using compatible model versions.
 """
 
 import os
@@ -23,17 +24,13 @@ from PIL import Image
 from transformers import AutoProcessor, AutoModel
 from tqdm import tqdm
 
-# -------------------------
-# Config
-# -------------------------
 SIGLIP_MODEL_ID = "EYEDOL/siglipFULL-agri-finetuned"
-LLAVA_MODEL_ID = "
+LLAVA_MODEL_ID = "your-llava-model-hf-id"  # replace with actual model
 DATASET_TEMPLATE = "EYEDOL/AGRILLAVA-image-text{}"
 NUM_DATASETS = 1
 BATCH_SIZE = 16
 TOP_K_DEFAULT = 3
 
-# Device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # -------------------------
@@ -80,15 +77,24 @@ def retrieve_top_k_texts(image: Image.Image, k=TOP_K_DEFAULT):
     return results
 
 # -------------------------
-# Llava
+# Lazy-load Llava model with error handling
 # -------------------------
-
+llava_model_cache = {}
+
 def load_llava_model():
-
-
-
-
-
+    if 'model' in llava_model_cache and 'tokenizer' in llava_model_cache:
+        return llava_model_cache['tokenizer'], llava_model_cache['model']
+
+    try:
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(LLAVA_MODEL_ID)
+        model = AutoModelForCausalLM.from_pretrained(LLAVA_MODEL_ID).to(device)
+        model.eval()
+        llava_model_cache['tokenizer'] = tokenizer
+        llava_model_cache['model'] = model
+        return tokenizer, model
+    except Exception as e:
+        raise RuntimeError(f"Failed to load Llava model/tokenizer: {e}. Ensure LLAVA_MODEL_ID is correct and compatible with transformers.")
 
 def llava_answer(image: Image.Image, retrieved_texts: List[str], question: str, max_tokens=256):
     tokenizer, model = load_llava_model()
@@ -110,7 +116,10 @@ def gradio_pipeline(image: Image.Image, question: str, k: int = TOP_K_DEFAULT):
         return None, "Please provide both image and question."
 
     retrieved_texts = retrieve_top_k_texts(image, k=int(k))
-
+    try:
+        response = llava_answer(image, retrieved_texts, question)
+    except RuntimeError as e:
+        response = str(e)
     return image, response
 
 with gr.Blocks(title="Agri Image + Question → Llava Response") as demo:
@@ -126,4 +135,4 @@ with gr.Blocks(title="Agri Image + Question → Llava Response") as demo:
     run_btn.click(fn=gradio_pipeline, inputs=[img_in, question_input, k_slider], outputs=[out_img, txt_out])
 
 if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0", share=False)
+    demo.launch(server_name="0.0.0.0", share=False)
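For readers following the config hunk: DATASET_TEMPLATE and NUM_DATASETS suggest the retrieval corpus is assembled from one or more numbered Hub datasets. Below is a minimal sketch of that loading step, assuming the {} placeholder takes the dataset index; the actual indexing scheme is not visible in this diff.

# Sketch only: assumes {} is filled with 1..NUM_DATASETS. The real app.py
# may format the template differently (e.g. an empty suffix for the first set).
from datasets import load_dataset, concatenate_datasets

DATASET_TEMPLATE = "EYEDOL/AGRILLAVA-image-text{}"
NUM_DATASETS = 1

parts = [load_dataset(DATASET_TEMPLATE.format(i), split="train")
         for i in range(1, NUM_DATASETS + 1)]
corpus = concatenate_datasets(parts)  # combined image-text corpus for retrieval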
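retrieve_top_k_texts() is untouched by this commit, so its body does not appear in the diff. For orientation, here is a sketch of how SigLip top-k retrieval over precomputed caption embeddings typically works; corpus_texts and text_embeds are hypothetical names standing in for whatever app.py builds from the dataset, not the app's actual variables.

# Sketch only: assumes caption embeddings were precomputed with the same
# SigLip checkpoint and L2-normalised (shape [num_texts, dim]).
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = AutoProcessor.from_pretrained("EYEDOL/siglipFULL-agri-finetuned")
siglip = AutoModel.from_pretrained("EYEDOL/siglipFULL-agri-finetuned").to(device).eval()

def retrieve_top_k_texts(image: Image.Image, corpus_texts, text_embeds, k=3):
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        img = siglip.get_image_features(**inputs)   # [1, dim] image embedding
    img = img / img.norm(dim=-1, keepdim=True)      # L2-normalise
    scores = (img @ text_embeds.T).squeeze(0)       # cosine similarities
    top = scores.topk(min(k, len(corpus_texts))).indices
    return [corpus_texts[i] for i in top.tolist()]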
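The dict-based cache in the new load_llava_model() works; the standard library offers the same lazy-load-once behaviour with less bookkeeping. An equivalent sketch using functools.lru_cache, keeping the diff's loading calls and error handling:

# Equivalent lazy loader: maxsize=1 caches the single (tokenizer, model)
# pair after the first successful call, so later calls are a dict lookup.
from functools import lru_cache

@lru_cache(maxsize=1)
def load_llava_model():
    try:
        from transformers import AutoModelForCausalLM, AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(LLAVA_MODEL_ID)
        model = AutoModelForCausalLM.from_pretrained(LLAVA_MODEL_ID).to(device)
        model.eval()
        return tokenizer, model
    except Exception as e:
        raise RuntimeError(f"Failed to load Llava model/tokenizer: {e}") from e

Like the dict version in the diff, a failed load is not cached, so a transient download error can simply be retried on the next click.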
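One caveat on the loading strategy itself: a tokenizer-only pipeline has no channel for pixel data, so with AutoTokenizer/AutoModelForCausalLM the llava_answer() call can pass the retrieved texts but not the uploaded image. If LLAVA_MODEL_ID points at one of the official llava-hf conversions, transformers exposes image-aware classes instead. A sketch under that assumption; "llava-hf/llava-1.5-7b-hf" merely stands in for the unspecified checkpoint.

# Sketch only: multimodal path for llava-hf-style checkpoints.
# device_map="auto" additionally requires the accelerate package.
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

llava_processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
llava_model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf", torch_dtype=torch.float16, device_map="auto"
)

def llava_answer(image, retrieved_texts, question, max_tokens=256):
    context = "\n".join(retrieved_texts)
    prompt = f"USER: <image>\nContext:\n{context}\nQuestion: {question}\nASSISTANT:"
    inputs = llava_processor(images=image, text=prompt, return_tensors="pt").to(llava_model.device)
    output = llava_model.generate(**inputs, max_new_tokens=max_tokens)
    # Decode only the newly generated tokens, skipping the prompt.
    answer = llava_processor.decode(output[0][inputs["input_ids"].shape[1]:],
                                    skip_special_tokens=True)
    return answer.strip()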