EYEDOL committed
Commit 9495b41 · verified · 1 Parent(s): 2e9bf08

Update app.py

Files changed (1):
  1. app.py +23 -17
app.py CHANGED
@@ -1,10 +1,9 @@
 """
-Gradio Space app (app.py) — Preloaded SigLip + Llava pipeline for instant response
-
+Gradio Space app: Preloaded SigLip + Llava pipeline for instant user response.
 Pipeline:
-1. At startup: load SigLip processor & model, compute all text embeddings.
-2. At startup: load Llava tokenizer & model.
-3. User uploads an image and asks a question → pipeline uses preloaded resources for instant retrieval and response.
+1. Startup: load SigLip processor, model, compute all text embeddings.
+2. Startup: load Llava tokenizer & LlavaForCausalLM model.
+3. User uploads image + asks question → instant retrieval + Llava response.
 """
 
 import os
@@ -13,13 +12,17 @@ from typing import List, Tuple
 import gradio as gr
 import torch
 import torch.nn.functional as F
-from datasets import load_dataset, concatenate_datasets
+from datasets import load_dataset
 from PIL import Image
-from transformers import AutoProcessor, AutoModel, AutoTokenizer, AutoModelForCausalLM
-from tqdm import tqdm
+from transformers import AutoProcessor, AutoModel  # AutoModel is still needed below for SigLip
+
+# Install llava repo if not already installed:
+# pip install git+https://github.com/haotian-liu/LLaVA.git
+from llava.model import LlavaForCausalLM
+from transformers import AutoTokenizer
 
 SIGLIP_MODEL_ID = "EYEDOL/siglipFULL-agri-finetuned"
-LLAVA_MODEL_ID = "llava-hf/llava-1.5-7b-hf"  # replace with actual model
+LLAVA_MODEL_ID = "llava-hf/llava-1.5-7b-hf"  # replace with your actual model repo
 DATASET_TEMPLATE = "EYEDOL/AGRILLAVA-image-text{}"
 NUM_DATASETS = 1
 BATCH_SIZE = 16
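
Dropping `concatenate_datasets` is consistent with `NUM_DATASETS = 1`. For context, a minimal sketch of how `texts_all` is presumably filled from the templated dataset repos; the starting index, the "train" split, and the "text" column are assumptions not shown in this diff:

# Sketch (assumed): populate texts_all from the templated dataset repos.
for idx in range(1, NUM_DATASETS + 1):
    ds = load_dataset(DATASET_TEMPLATE.format(idx), split="train")  # split name assumed
    texts_all.extend(ds["text"])  # column name assumed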
@@ -28,7 +31,7 @@ TOP_K_DEFAULT = 3
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # -------------------------
-# Startup: load all datasets and compute text embeddings
+# Startup: load datasets and compute SigLip text embeddings
 # -------------------------
 print("⏳ Loading datasets and computing SigLip text embeddings...")
 texts_all = []
@@ -41,7 +44,7 @@ siglip_model = AutoModel.from_pretrained(SIGLIP_MODEL_ID).to(device)
 siglip_model.eval()
 
 text_embeds_all = []
-for i in tqdm(range(0, len(texts_all), BATCH_SIZE), desc="Encoding texts"):
+for i in range(0, len(texts_all), BATCH_SIZE):
     batch_texts = texts_all[i:i+BATCH_SIZE]
     inputs = siglip_processor(text=batch_texts, padding=True, truncation=True, return_tensors="pt").to(device)
     with torch.no_grad():
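
The loop body under `torch.no_grad()` is elided by the diff. A sketch of the presumed encoding step; `get_text_features` and the L2 normalization are assumptions, chosen to be consistent with the cosine-similarity retrieval below and the shape print in the next hunk's header:

# Sketch (assumed): the encoding loop with its elided body.
text_embeds_all = []
for i in range(0, len(texts_all), BATCH_SIZE):
    batch_texts = texts_all[i:i+BATCH_SIZE]
    inputs = siglip_processor(text=batch_texts, padding=True, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():
        feats = siglip_model.get_text_features(**inputs)      # (batch, dim)
    text_embeds_all.append(F.normalize(feats, dim=-1).cpu())  # unit norm for cosine similarity
text_embeds_all = torch.cat(text_embeds_all, dim=0)           # (num_texts, dim)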
@@ -57,14 +60,14 @@ print(f"✅ Finished encoding {len(texts_all)} texts. Shape: {text_embeds_all.shape}")
 # -------------------------
 # Startup: load Llava model & tokenizer
 # -------------------------
-print("⏳ Loading Llava model and tokenizer...")
+print("⏳ Loading Llava tokenizer and LlavaForCausalLM model...")
 llava_tokenizer = AutoTokenizer.from_pretrained(LLAVA_MODEL_ID, use_fast=False)
-llava_model = AutoModelForCausalLM.from_pretrained(LLAVA_MODEL_ID).to(device)
+llava_model = LlavaForCausalLM.from_pretrained(LLAVA_MODEL_ID).to(device)
 llava_model.eval()
 print("✅ Llava model loaded.")
 
 # -------------------------
-# SigLip retrieval
+# SigLip retrieval function
 # -------------------------
 
 def retrieve_top_k_texts(image: Image.Image, k=TOP_K_DEFAULT):
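
The body of `retrieve_top_k_texts` is elided. A minimal compatible sketch; `get_image_features` and returning plain strings (to match `llava_answer`'s `List[str]` parameter) are assumptions:

# Sketch (assumed): embed the image and take the k most similar preloaded texts.
def retrieve_top_k_texts(image: Image.Image, k=TOP_K_DEFAULT):
    inputs = siglip_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        img_feat = siglip_model.get_image_features(**inputs)
    img_feat = F.normalize(img_feat, dim=-1).cpu()
    sims = (img_feat @ text_embeds_all.T).squeeze(0)  # cosine similarities, (num_texts,)
    top = torch.topk(sims, k=min(k, sims.numel()))
    results = [texts_all[i] for i in top.indices.tolist()]
    return results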
@@ -79,7 +82,7 @@ def retrieve_top_k_texts(image: Image.Image, k=TOP_K_DEFAULT):
     return results
 
 # -------------------------
-# Llava answer
+# Llava answer function
 # -------------------------
 
 def llava_answer(image: Image.Image, retrieved_texts: List[str], question: str, max_tokens=256):
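
The diff also elides `llava_answer`, and the commit swaps the model class to the LLaVA repo's `LlavaForCausalLM`, whose generation call is not shown here. For illustration only, a sketch using the transformers Llava classes instead (`LlavaForConditionalGeneration` plus `AutoProcessor`, which match the `llava-hf/llava-1.5-7b-hf` checkpoint id); the prompt template and every name in the body are assumptions:

# Sketch (assumed, transformers API rather than the repo's llava.model API):
from transformers import LlavaForConditionalGeneration

llava_processor = AutoProcessor.from_pretrained(LLAVA_MODEL_ID)
hf_llava = LlavaForConditionalGeneration.from_pretrained(LLAVA_MODEL_ID).to(device)

def llava_answer(image: Image.Image, retrieved_texts: List[str], question: str, max_tokens=256):
    context = "\n".join(retrieved_texts)
    prompt = f"USER: <image>\nContext:\n{context}\nQuestion: {question}\nASSISTANT:"
    inputs = llava_processor(images=image, text=prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        out = hf_llava.generate(**inputs, max_new_tokens=max_tokens)
    text = llava_processor.batch_decode(out, skip_special_tokens=True)[0]
    return text.split("ASSISTANT:")[-1].strip()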
@@ -93,7 +96,7 @@ def llava_answer(image: Image.Image, retrieved_texts: List[str], question: str, max_tokens=256):
     return response
 
 # -------------------------
-# Gradio interface
+# Gradio interface pipeline
 # -------------------------
 
 def gradio_pipeline(image: Image.Image, question: str, k: int = TOP_K_DEFAULT):
@@ -104,6 +107,9 @@ def gradio_pipeline(image: Image.Image, question: str, k: int = TOP_K_DEFAULT):
     response = llava_answer(image, retrieved_texts, question)
     return image, response
 
+# -------------------------
+# Gradio Blocks
+# -------------------------
 with gr.Blocks(title="Agri Image + Question → Llava Response") as demo:
     gr.Markdown("# Agri Image Question Answering\nUpload an agriculture image, ask a question, and get context-aware crop suggestions.")
     with gr.Row():
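
The component definitions inside the Blocks are elided between these hunks; from the names wired into `run_btn.click(...)` in the next hunk, the layout is plausibly along these lines (labels and slider range are assumptions):

# Sketch (assumed): the elided widget definitions, inferred from run_btn.click below.
    with gr.Row():
        img_in = gr.Image(type="pil", label="Upload image")
        out_img = gr.Image(type="pil", label="Image")
    question_input = gr.Textbox(label="Question")
    k_slider = gr.Slider(minimum=1, maximum=10, value=TOP_K_DEFAULT, step=1, label="Top-k retrieved texts")
    txt_out = gr.Textbox(label="Llava response")
    run_btn = gr.Button("Run")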
@@ -117,4 +123,4 @@ with gr.Blocks(title="Agri Image + Question → Llava Response") as demo:
     run_btn.click(fn=gradio_pipeline, inputs=[img_in, question_input, k_slider], outputs=[out_img, txt_out])
 
 if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0", share=False)
+    demo.launch(server_name="0.0.0.0", share=False)
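
End to end, a request runs SigLip retrieval and then Llava generation in one call. A quick programmatic smoke test of the pipeline (the image path and question are placeholders):

# Smoke test (placeholder path and question):
img = Image.open("example_crop.jpg").convert("RGB")
_, answer = gradio_pipeline(img, "What disease is affecting this crop?", k=3)
print(answer)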
 