# app.py - Robust CPU-friendly SigLip -> (Llava local | trust_remote_code | HF router) pipeline
import os
# Force CPU before importing torch/transformers if you want CPU-only
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
os.environ.setdefault("TRANSFORMERS_NO_ADVISORY_WARNINGS", "1")
import sys
import traceback
import json
from typing import List, Optional
import requests
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import (
AutoProcessor,
AutoModel,
AutoTokenizer,
AutoModelForCausalLM,
)
from PIL import Image
import gradio as gr
from tqdm import tqdm
# -------------------------
# Config - update these IDs as needed
# -------------------------
SIGLIP_MODEL_ID = "EYEDOL/siglipFULL-agri-finetuned"
LLAVA_MODEL_ID = "liuhaotian/llava-v1.6-vicuna-7b"
DATASET_TEMPLATE = "EYEDOL/AGRILLAVA-image-text{}"
NUM_DATASETS = 1 # set to 15 if you want all datasets (startup memory/time increases)
BATCH_SIZE = 16
TOP_K_DEFAULT = 3
# Hugging Face router inference endpoint; the model path is appended per request
# (see call_hf_inference_api below)
HF_API_URL = "https://router.huggingface.co/hf-inference"
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", None)
# Device - CPU only
device = torch.device("cpu")
print("Running on device:", device)
# -------------------------
# Load dataset and SigLip model & precompute text embeddings at startup
# -------------------------
print("Loading datasets and computing SigLip text embeddings (startup)...")
texts_all: List[str] = []
for i in range(1, NUM_DATASETS + 1):
    ds = load_dataset(DATASET_TEMPLATE.format(i), split="train")
    texts_all.extend(ds["text"])
siglip_processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_ID)
siglip_model = AutoModel.from_pretrained(SIGLIP_MODEL_ID).to(device)
siglip_model.eval()
# Precompute text embeddings (on CPU)
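# Because every embedding is L2-normalized below, cosine similarity at query time
# reduces to a plain dot product over this matrix.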
text_embeds_parts = []
for i in tqdm(range(0, len(texts_all), BATCH_SIZE), desc="Encoding texts (CPU)"):
    batch_texts = texts_all[i : i + BATCH_SIZE]
    inputs = siglip_processor(text=batch_texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        text_embeds = siglip_model.get_text_features(**inputs)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
    text_embeds_parts.append(text_embeds.cpu())
    del inputs, text_embeds
if text_embeds_parts:
    text_embeds_all = torch.cat(text_embeds_parts, dim=0)
else:
    text_embeds_all = torch.empty((0, 0))
print(f"Encoded {len(texts_all)} texts. Embeddings shape: {text_embeds_all.shape}")
# -------------------------
# Llava loading: try local package -> trust_remote_code -> HF Inference API (if token provided)
# -------------------------
llava_tokenizer: Optional[AutoTokenizer] = None
llava_model = None
llava_mode: Optional[str] = None # 'local', 'trust_remote_code', 'hf_api', or None
load_errors = []
# Attempt 1: local llava package (preferred)
try:
    # this import requires the LLaVA repo to be installed in the environment (requirements.txt)
    from llava.model import LlavaForCausalLM  # type: ignore

    print("Loading LlavaForCausalLM from installed 'llava' package (CPU)...")
    llava_tokenizer = AutoTokenizer.from_pretrained(LLAVA_MODEL_ID, use_fast=False)
    llava_model = LlavaForCausalLM.from_pretrained(
        LLAVA_MODEL_ID,
        device_map={"": "cpu"},
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    llava_model.to(device)
    llava_model.eval()
    llava_mode = "local"
    print("Llava loaded from installed package.")
except Exception:
    tb_local = traceback.format_exc()
    load_errors.append(("local_llava_import", tb_local))
    print("Local llava import failed; will try trust_remote_code fallback. See logs for details.")
# Attempt 2: trust_remote_code fallback
if llava_mode is None:
    try:
        print("Attempting AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True) (CPU)...")
        llava_tokenizer = AutoTokenizer.from_pretrained(LLAVA_MODEL_ID, use_fast=False)
        llava_model = AutoModelForCausalLM.from_pretrained(
            LLAVA_MODEL_ID,
            trust_remote_code=True,
            device_map={"": "cpu"},
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
        )
        llava_model.to(device)
        llava_model.eval()
        llava_mode = "trust_remote_code"
        print("Llava loaded via trust_remote_code fallback.")
    except Exception:
        tb_trust = traceback.format_exc()
        load_errors.append(("fallback_trust_remote_code", tb_trust))
        print("trust_remote_code fallback failed; will try HF router if token provided.")
# Attempt 3: Hugging Face router Inference API fallback (requires HUGGINGFACE_TOKEN)
if llava_mode is None and HUGGINGFACE_TOKEN:
    llava_mode = "hf_api"
    print("No usable local model found. Will use Hugging Face router Inference API for generation (HUGGINGFACE_TOKEN detected).")
if llava_mode is None:
    print("WARNING: No Llava model available and no HUGGINGFACE_TOKEN supplied. Generation will return an actionable error.")
    for name, tb in load_errors:
        print(f"--- {name} traceback ---\n{tb}")
# -------------------------
# Helper: call Hugging Face router inference API
# -------------------------
def call_hf_inference_api(prompt: str, max_new_tokens: int = 256, temperature: float = 0.0):
    if not HUGGINGFACE_TOKEN:
        raise RuntimeError("HUGGINGFACE_TOKEN not set; cannot call Hugging Face Inference API.")
    headers = {"Authorization": f"Bearer {HUGGINGFACE_TOKEN}", "Content-Type": "application/json"}
    payload = {
        "inputs": prompt,
        "parameters": {"max_new_tokens": max_new_tokens, "temperature": temperature},
        "options": {"wait_for_model": True},
    }
    # The router addresses models by URL path (.../models/<model_id>) rather than
    # a "model" field in the payload.
    resp = requests.post(f"{HF_API_URL}/models/{LLAVA_MODEL_ID}", headers=headers, json=payload, timeout=300)
    if resp.status_code != 200:
        raise RuntimeError(f"HF Inference API error {resp.status_code}: {resp.text}")
    data = resp.json()
    # handle common response shapes
    if isinstance(data, list) and data and isinstance(data[0], dict) and "generated_text" in data[0]:
        return data[0]["generated_text"]  # standard text-generation shape: [{"generated_text": ...}]
    if isinstance(data, dict) and "generated_text" in data:
        return data["generated_text"]
    if isinstance(data, str):
        return data
    return json.dumps(data)
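# Minimal usage sketch (assumes HUGGINGFACE_TOKEN is set and the model is served
# by the hf-inference provider):
#   answer = call_hf_inference_api("List two signs of maize leaf blight.", max_new_tokens=64)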
# -------------------------
# Retrieval & generation
# -------------------------
def retrieve_top_k_texts(image: Image.Image, k: int = TOP_K_DEFAULT):
    inputs = siglip_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        img_embed = siglip_model.get_image_features(**inputs)
        img_embed = img_embed / img_embed.norm(p=2, dim=-1, keepdim=True)
    sims = F.cosine_similarity(img_embed.cpu(), text_embeds_all)
    k = min(k, sims.numel())  # guard against k exceeding the corpus size
    topk = torch.topk(sims, k)
    results = [(texts_all[idx.item()], float(score)) for idx, score in zip(topk.indices, topk.values)]
    return results
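# Minimal usage sketch ("leaf.jpg" is a hypothetical local file):
#   hits = retrieve_top_k_texts(Image.open("leaf.jpg").convert("RGB"), k=3)
#   for text, score in hits:
#       print(f"{score:.3f}  {text[:80]}")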
def llava_answer(image: Image.Image, retrieved_texts, question: str, max_tokens: int = 256):
    # Note: the image is only used for retrieval; generation below is text-only,
    # conditioned on the retrieved passages and the question.
    context_text = "\n".join([f"Retrieved Text: {t}" for t, _ in retrieved_texts])
    prompt = (
        "You are an agricultural assistant. Use the provided retrieved texts to answer concisely.\n\n"
        f"Retrieved texts:\n{context_text}\n\n"
        f"User question: {question}\n\n"
        "Provide a concise, actionable answer and crop suggestions when applicable."
    )
    if llava_mode in ("local", "trust_remote_code"):
        inputs = llava_tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            output_ids = llava_model.generate(**inputs, max_new_tokens=max_tokens)
        resp = llava_tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return resp
    elif llava_mode == "hf_api":
        return call_hf_inference_api(prompt, max_new_tokens=max_tokens)
    else:
        err = (
            "No Llava model is available for generation.\n\n"
            "Fix options:\n"
            "1) Install the LLaVA repo in requirements.txt and rebuild the Space:\n"
            "   git+https://github.com/haotian-liu/LLaVA.git@main\n"
            "2) Or add a valid Hugging Face API token as HUGGINGFACE_TOKEN in Space secrets to use the router.\n\n"
            "Check Space logs for detailed tracebacks printed at startup."
        )
        return err
# -------------------------
# Gradio app
# -------------------------
def gradio_pipeline(image: Image.Image, question: str, k: int = TOP_K_DEFAULT):
    if image is None or not question:
        return None, "Please provide both an image and a question."
    retrieved = retrieve_top_k_texts(image, k=int(k))
    try:
        answer = llava_answer(image, retrieved, question)
    except Exception as e:
        tb = traceback.format_exc()
        answer = f"Error during generation: {e}\n\nTraceback:\n{tb}"
    return image, answer
with gr.Blocks(title="Agri Image + Question -> Llava Response (robust)") as demo:
    gr.Markdown(
        "## Agri Image QA\n\nThis app preloads SigLip embeddings at startup. "
        "Generation uses a local Llava model if available, otherwise the Hugging Face router Inference API "
        "(requires HUGGINGFACE_TOKEN secret in Space settings)."
    )
    with gr.Row():
        img_in = gr.Image(type="pil")
        out_img = gr.Image(type="pil", label="Image")
    question_input = gr.Textbox(label="Question about the image", lines=2)
    k_slider = gr.Slider(minimum=1, maximum=10, step=1, value=TOP_K_DEFAULT, label="Top-k retrieval")
    txt_out = gr.Textbox(label="Llava Response", lines=12)
    run_btn = gr.Button("Generate Answer")
    run_btn.click(fn=gradio_pipeline, inputs=[img_in, question_input, k_slider], outputs=[out_img, txt_out])
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", share=False)