File size: 9,934 Bytes
a48863e
30a56dc
94b81b9
 
 
8c695f8
1278acb
 
94b81b9
a48863e
1278acb
a48863e
8c695f8
 
9495b41
a48863e
 
 
 
 
 
8c695f8
463c8c8
 
30a56dc
463c8c8
a48863e
463c8c8
43590b6
f30de13
43590b6
a48863e
43590b6
8c695f8
a48863e
 
 
94b81b9
8c695f8
a48863e
463c8c8
a48863e
8c695f8
 
a48863e
8c695f8
94b81b9
1278acb
0905af1
 
 
 
 
 
 
 
a48863e
94b81b9
a48863e
463c8c8
 
0905af1
 
 
94b81b9
0905af1
94b81b9
 
1278acb
94b81b9
 
0905af1
 
94b81b9
0905af1
94b81b9
1278acb
a48863e
1278acb
 
 
 
94b81b9
1278acb
 
94b81b9
1278acb
 
 
 
 
 
 
 
 
94b81b9
 
a48863e
1278acb
 
a48863e
94b81b9
 
 
1278acb
 
 
 
 
 
 
 
 
 
 
 
94b81b9
 
a48863e
94b81b9
 
a48863e
94b81b9
a48863e
94b81b9
 
a48863e
94b81b9
 
a48863e
94b81b9
a48863e
94b81b9
 
a48863e
94b81b9
 
 
 
a48863e
94b81b9
a48863e
94b81b9
 
 
 
 
 
 
 
a48863e
94b81b9
 
 
 
 
 
 
8c695f8
 
a48863e
8c695f8
1278acb
463c8c8
30a56dc
0905af1
30a56dc
 
8c695f8
 
43590b6
8c695f8
 
1278acb
43590b6
1278acb
94b81b9
1278acb
 
94b81b9
1278acb
43590b6
94b81b9
 
 
 
 
 
 
 
 
 
 
 
a48863e
94b81b9
 
a48863e
 
94b81b9
 
43590b6
8c695f8
a48863e
8c695f8
43590b6
 
463c8c8
 
1278acb
 
 
 
94b81b9
463c8c8
43590b6
94b81b9
1278acb
94b81b9
a48863e
 
1278acb
30a56dc
8c695f8
 
43590b6
 
1278acb
43590b6
 
30a56dc
 
463c8c8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
# app.py — Robust CPU-friendly SigLip -> (Llava local | trust_remote_code | HF router) pipeline
import os
# Force CPU before importing torch/transformers if you want CPU-only.
# An empty CUDA_VISIBLE_DEVICES hides every GPU from the process; setdefault()
# leaves the variable alone when the environment already defines it, so a
# deployment can still opt back into GPU use.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
# Defaulted to "1" unless already set. Going by its name this mutes
# transformers' advisory warnings — TODO confirm for the installed version.
os.environ.setdefault("TRANSFORMERS_NO_ADVISORY_WARNINGS", "1")

import sys
import traceback
import json
from typing import List, Optional

import requests
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import (
    AutoProcessor,
    AutoModel,
    AutoTokenizer,
    AutoModelForCausalLM,
)
from PIL import Image
import gradio as gr
from tqdm import tqdm

# -------------------------
# Config - update these IDs as needed
# -------------------------
# Checkpoint id loaded through AutoProcessor/AutoModel; used for both the
# startup text embeddings and the per-request image embedding.
SIGLIP_MODEL_ID = "EYEDOL/siglipFULL-agri-finetuned"
# Checkpoint id for the generation model; also sent as "model" in the HF router payload.
LLAVA_MODEL_ID = "liuhaotian/llava-v1.6-vicuna-7b"  
# Dataset id template; "{}" is filled with the 1-based dataset index at startup.
DATASET_TEMPLATE = "EYEDOL/AGRILLAVA-image-text{}"
NUM_DATASETS = 1         # set to 15 if you want all datasets (startup memory/time increases)
BATCH_SIZE = 16  # number of texts encoded per SigLip call during startup
TOP_K_DEFAULT = 3  # default number of retrieved texts (function default and slider default)

# Hugging Face router endpoint (new inference endpoint)
HF_API_URL = "https://router.huggingface.co/hf-inference"
# None when the variable is unset. The HF router fallback is selected only
# when this is truthy and no local model could be loaded.
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", None)

# Device - CPU only
device = torch.device("cpu")
print("Running on device:", device)

# -------------------------
# Startup: gather the retrieval corpus, load SigLip, embed every text once
# -------------------------
print("Loading datasets and computing SigLip text embeddings (startup)...")

# Retrieval corpus: the "text" column of each dataset's train split,
# concatenated in dataset order (dataset names use 1-based indices).
texts_all: List[str] = []
for dataset_index in range(1, NUM_DATASETS + 1):
    dataset_name = DATASET_TEMPLATE.format(dataset_index)
    train_split = load_dataset(dataset_name, split="train")
    texts_all.extend(train_split["text"])

siglip_processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_ID)
siglip_model = AutoModel.from_pretrained(SIGLIP_MODEL_ID).to(device)
siglip_model.eval()

# Encode the corpus in BATCH_SIZE slices. Every slice is L2-normalised along
# its last axis and collected on CPU until the final concatenation.
text_embeds_parts = []
for batch_start in tqdm(range(0, len(texts_all), BATCH_SIZE), desc="Encoding texts (CPU)"):
    batch_end = batch_start + BATCH_SIZE
    encoded = siglip_processor(
        text=texts_all[batch_start:batch_end],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        features = siglip_model.get_text_features(**encoded)
        unit_features = features / features.norm(p=2, dim=-1, keepdim=True)
    text_embeds_parts.append(unit_features.cpu())
    # Drop the per-batch references right away so they can be reclaimed.
    del encoded, features, unit_features

# Row i of text_embeds_all belongs to texts_all[i]; an empty corpus yields a
# (0, 0) placeholder tensor instead of a concatenation.
text_embeds_all = torch.cat(text_embeds_parts, dim=0) if text_embeds_parts else torch.empty((0, 0))
print(f"Encoded {len(texts_all)} texts. Embeddings shape: {text_embeds_all.shape}")

# -------------------------
# Llava loading: try local package -> trust_remote_code -> HF Inference API (if token provided)
# -------------------------
# The outcome of this section is recorded in three module-level names that
# llava_answer() reads: llava_tokenizer, llava_model and llava_mode.
llava_tokenizer: Optional[AutoTokenizer] = None
llava_model = None
llava_mode: Optional[str] = None  # 'local', 'trust_remote_code', 'hf_api', or None
# (stage name, formatted traceback) for every loading attempt that failed.
load_errors = []

# Attempt 1: local llava package (preferred)
try:
    # this import requires the LLaVA repo to be installed in the environment (requirements.txt)
    # NOTE(review): confirm the installed llava package really exports
    # `LlavaForCausalLM`; if it does not, this attempt always falls through.
    from llava.model import LlavaForCausalLM  # type: ignore

    print("Loading LlavaForCausalLM from installed 'llava' package (CPU)...")
    llava_tokenizer = AutoTokenizer.from_pretrained(LLAVA_MODEL_ID, use_fast=False)
    llava_model = LlavaForCausalLM.from_pretrained(
        LLAVA_MODEL_ID,
        device_map={"": "cpu"},
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    llava_model.to(device)
    llava_model.eval()
    llava_mode = "local"
    print("✅ Llava loaded from installed package.")
except Exception:
    # Deliberately broad: any failure in this attempt (import, download,
    # model construction) moves on to the next strategy instead of aborting startup.
    tb_local = traceback.format_exc()
    load_errors.append(("local_llava_import", tb_local))
    print("Local llava import failed — will try trust_remote_code fallback. See logs for details.")

# Attempt 2: trust_remote_code fallback
if llava_mode is None:
    try:
        print("Attempting AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True) (CPU)...")
        llava_tokenizer = AutoTokenizer.from_pretrained(LLAVA_MODEL_ID, use_fast=False)
        llava_model = AutoModelForCausalLM.from_pretrained(
            LLAVA_MODEL_ID,
            trust_remote_code=True,
            device_map={"": "cpu"},
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
        )
        llava_model.to(device)
        llava_model.eval()
        llava_mode = "trust_remote_code"
        print("✅ Llava loaded via trust_remote_code fallback.")
    except Exception:
        # Same best-effort policy as attempt 1: record the traceback and continue.
        tb_trust = traceback.format_exc()
        load_errors.append(("fallback_trust_remote_code", tb_trust))
        print("trust_remote_code fallback failed — will try HF router if token provided.")

# Attempt 3: Hugging Face router Inference API fallback (requires HUGGINGFACE_TOKEN)
# No request is made here; the mode is selected purely because a token exists.
if llava_mode is None and HUGGINGFACE_TOKEN:
    llava_mode = "hf_api"
    print("No usable local model found. Will use Hugging Face router Inference API for generation (HUGGINGFACE_TOKEN detected).")

# The collected tracebacks are printed only when every strategy failed;
# when the hf_api mode is selected they stay in load_errors unprinted.
if llava_mode is None:
    print("WARNING: No Llava model available and no HUGGINGFACE_TOKEN supplied. Generation will return an actionable error.")
    for name, tb in load_errors:
        print(f"--- {name} traceback ---\n{tb}")

# -------------------------
# Helper: call Hugging Face router inference API
# -------------------------
def call_hf_inference_api(prompt: str, max_new_tokens: int = 256, temperature: float = 0.0):
    """Send *prompt* to the Hugging Face router endpoint and return the reply text.

    Args:
        prompt: Text sent as the "inputs" field of the request.
        max_new_tokens: Forwarded in the request's "parameters".
        temperature: Forwarded in the request's "parameters".

    Returns:
        The "generated_text" value when the JSON body is a dict carrying it,
        or a non-empty list whose first item is such a dict; the body itself
        when it is a plain string; otherwise the body re-serialised as JSON.

    Raises:
        RuntimeError: If HUGGINGFACE_TOKEN is not set, or the endpoint answers
            with a status code other than 200.
    """
    if not HUGGINGFACE_TOKEN:
        raise RuntimeError("HUGGINGFACE_TOKEN not set; cannot call Hugging Face Inference API.")

    generation_params = {"max_new_tokens": max_new_tokens, "temperature": temperature}
    response = requests.post(
        HF_API_URL,
        headers={
            "Authorization": f"Bearer {HUGGINGFACE_TOKEN}",
            "Content-Type": "application/json",
        },
        json={
            "model": LLAVA_MODEL_ID,
            "inputs": prompt,
            "parameters": generation_params,
            "options": {"wait_for_model": True},
        },
        timeout=300,
    )
    if response.status_code != 200:
        raise RuntimeError(f"HF Inference API error {response.status_code}: {response.text}")

    body = response.json()
    if isinstance(body, str):
        return body
    # A non-empty list is represented by its first element; anything else is
    # inspected directly.
    candidate = body[0] if isinstance(body, list) and body else body
    if isinstance(candidate, dict) and "generated_text" in candidate:
        return candidate["generated_text"]
    # Unknown shape: hand back the raw JSON so the caller can still see it.
    return json.dumps(body)

# -------------------------
# Retrieval & generation
# -------------------------
def retrieve_top_k_texts(image: Image.Image, k: int = TOP_K_DEFAULT):
    """Return the corpus texts most similar to *image*, best match first.

    The image is embedded with the SigLip model, L2-normalised, and compared
    by cosine similarity against the text embeddings precomputed at startup.

    Args:
        image: Query image handed to the SigLip processor.
        k: Maximum number of texts to return. It is clamped to the number of
            available texts, so asking for more than the corpus holds returns
            the whole corpus instead of raising.

    Returns:
        A list of ``(text, similarity)`` tuples with ``similarity`` as a
        Python float. The list is empty when the corpus is empty or ``k`` is
        not positive.
    """
    # torch.topk raises when k exceeds the size of the searched dimension, and
    # the (0, 0) placeholder used for an empty corpus cannot be compared at
    # all, so settle the effective k before doing any model work.
    available = min(len(texts_all), text_embeds_all.shape[0])
    k = min(int(k), available)
    if k <= 0:
        return []

    inputs = siglip_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        img_embed = siglip_model.get_image_features(**inputs)
        img_embed = img_embed / img_embed.norm(p=2, dim=-1, keepdim=True)

    # assumes a single query image, so img_embed broadcasts against every
    # text row and sims has one score per text — TODO confirm
    sims = F.cosine_similarity(img_embed.cpu(), text_embeds_all)
    topk = torch.topk(sims, k)
    return [
        (texts_all[index], score)
        for index, score in zip(topk.indices.tolist(), topk.values.tolist())
    ]

def llava_answer(image: Image.Image, retrieved_texts, question: str, max_tokens: int = 256):
    """Answer *question* using the retrieved texts as context.

    Args:
        image: Accepted for interface symmetry; this function never reads it.
        retrieved_texts: Iterable of 2-item entries whose first item is the
            text; the second item is ignored.
        question: The user's question, embedded verbatim in the prompt.
        max_tokens: Generation budget passed as ``max_new_tokens``.

    Returns:
        The decoded model output (local modes), the HF router reply
        (``hf_api`` mode), or a help message explaining how to make a model
        available when no mode is configured.
    """
    context_lines = []
    for text, _score in retrieved_texts:
        context_lines.append(f"Retrieved Text: {text}")
    context_text = "\n".join(context_lines)

    prompt_parts = [
        "You are an agricultural assistant. Use the provided retrieved texts to answer concisely.\n\n",
        f"Retrieved texts:\n{context_text}\n\n",
        f"User question: {question}\n\n",
        "Provide a concise, actionable answer and crop suggestions when applicable.",
    ]
    prompt = "".join(prompt_parts)

    # Remote generation: delegate to the router helper.
    if llava_mode == "hf_api":
        return call_hf_inference_api(prompt, max_new_tokens=max_tokens)

    # Nothing loaded and no token: return guidance instead of raising.
    if llava_mode not in ("local", "trust_remote_code"):
        return (
            "No Llava model is available for generation.\n\n"
            "Fix options:\n"
            "1) Install the LLaVA repo in requirements.txt and rebuild the Space:\n"
            "   git+https://github.com/haotian-liu/LLaVA.git@main\n"
            "2) Or add a valid Hugging Face API token as HUGGINGFACE_TOKEN in Space secrets to use the router.\n\n"
            "Check Space logs for detailed tracebacks printed at startup."
        )

    # In-process generation with the locally loaded tokenizer and model.
    encoded = llava_tokenizer(prompt, return_tensors="pt")
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        generated = llava_model.generate(**model_inputs, max_new_tokens=max_tokens)
    return llava_tokenizer.decode(generated[0], skip_special_tokens=True)

# -------------------------
# Gradio app
# -------------------------
def gradio_pipeline(image: Image.Image, question: str, k: int = TOP_K_DEFAULT):
    """Gradio click handler: retrieve context for *image*, then answer *question*.

    Args:
        image: Uploaded image, or None when nothing was uploaded.
        question: Question text from the textbox.
        k: Number of texts to retrieve; coerced with ``int()`` before use.

    Returns:
        ``(image, answer)`` on a normal run. ``(None, message)`` when the
        image or the question is missing. If generation raises, the answer is
        the error text followed by the traceback.
    """
    inputs_missing = image is None or not question
    if inputs_missing:
        return None, "Please provide both an image and a question."

    top_k = int(k)
    matches = retrieve_top_k_texts(image, k=top_k)
    try:
        reply = llava_answer(image, matches, question)
    except Exception as exc:
        # Show the failure in the UI textbox rather than failing the request.
        trace = traceback.format_exc()
        reply = f"Error during generation: {exc}\n\nTraceback:\n{trace}"
    return image, reply

# UI definition. The page title previously contained a mis-encoded arrow;
# it is written here as the intended "→" character.
with gr.Blocks(title="Agri Image + Question → Llava Response (robust)") as demo:
    gr.Markdown(
        "## Agri Image QA\n\nThis app preloads SigLip embeddings at startup. "
        "Generation uses a local Llava model if available, otherwise the Hugging Face router Inference API "
        "(requires HUGGINGFACE_TOKEN secret in Space settings)."
    )
    with gr.Row():
        # Upload widget next to an output widget that receives the image
        # returned by gradio_pipeline.
        img_in = gr.Image(type="pil")
        out_img = gr.Image(type="pil", label="Image")
    question_input = gr.Textbox(label="Question about the image", lines=2)
    k_slider = gr.Slider(minimum=1, maximum=10, step=1, value=TOP_K_DEFAULT, label="Top-k retrieval")
    txt_out = gr.Textbox(label="Llava Response", lines=12)
    run_btn = gr.Button("Generate Answer")
    # Inputs map positionally to gradio_pipeline(image, question, k); its
    # (image, answer) return value fills out_img and txt_out.
    run_btn.click(fn=gradio_pipeline, inputs=[img_in, question_input, k_slider], outputs=[out_img, txt_out])

if __name__ == "__main__":
    # Bind to 0.0.0.0 with no public share link.
    demo.launch(server_name="0.0.0.0", share=False)