# image_qa/src/streamlit_app.py
# Author: ClaudiaRichard — commit 44e69d2 ("update streamlit.py", verified)
import torch
import streamlit as st
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
import os
# --- Setup HF cache to writable directory (for Spaces) ---
# NOTE(review): transformers is already imported above, so some code paths may
# read HF_HOME before this line takes effect; the explicit cache_dir passed to
# from_pretrained() below covers the model/processor downloads regardless.
os.environ["HF_HOME"] = "/tmp/hf-cache"

# Fail fast with an actionable message instead of a bare KeyError when the
# token secret is not configured on the Space.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError(
        "HF_TOKEN environment variable is not set; "
        "add it as a secret so the gated model can be downloaded."
    )
# Set up model and processor with 4-bit quantization
@st.cache_resource(show_spinner="Loading Llama-3.2 Vision model...")  # Cache so it's downloaded only once
def load_model():
    """Load the gated Llama-3.2-11B-Vision model in 4-bit NF4 and its processor.

    Returns:
        tuple: (model, processor) — the quantized MllamaForConditionalGeneration
        (device-mapped automatically) and its matching AutoProcessor.
    """
    model_id = "meta-llama/Llama-3.2-11B-Vision"
    # NF4 with double quantization keeps the 11B model small enough for a
    # single modest GPU; compute still happens in float16.
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
    model = MllamaForConditionalGeneration.from_pretrained(
        model_id,
        quantization_config=quant_config,
        torch_dtype=torch.float16,
        device_map="auto",          # let accelerate place layers on available devices
        token=HF_TOKEN,             # gated repo: requires an authorized token
        cache_dir="/tmp/hf-cache",  # writable path on Spaces
    )
    processor = AutoProcessor.from_pretrained(model_id, token=HF_TOKEN, cache_dir="/tmp/hf-cache")
    return model, processor
# Load model and processor (cached across Streamlit reruns)
model, processor = load_model()

# --- Streamlit UI ---
# The model is Llama-3.2 Vision, not LLaVA — label it correctly.
st.title("🧠 Llama-3.2 Vision QA (4-bit)")

uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
prompt = st.text_input("Enter your prompt", value="What do I want?")

if uploaded_file and prompt:
    # Normalize to RGB so palette/alpha (PNG) images don't break the processor.
    image = Image.open(uploaded_file).convert("RGB")
    # use_column_width is deprecated; use_container_width is the replacement.
    st.image(image, caption="Uploaded Image", use_container_width=True)

    # Base (non-instruct) Llama-3.2 Vision prompt format: image token first.
    full_prompt = f"<|image|><|begin_of_text|>{prompt}"
    inputs = processor(image, full_prompt, return_tensors="pt").to(model.device)

    st.write("🌀 Generating response...")
    output = model.generate(**inputs, max_new_tokens=50)

    # Decode only the newly generated tokens — decoding output[0] directly
    # would echo the prompt back to the user.
    generated_tokens = output[0][inputs["input_ids"].shape[-1]:]
    response = processor.decode(generated_tokens, skip_special_tokens=True)
    st.success("✅ Response:")
    st.markdown(response)