import os

# Point the Hugging Face cache at a writable directory *before* importing
# transformers/huggingface_hub: the hub library resolves HF_HOME when it is
# imported, so setting it afterwards (as the original code did) may be a no-op.
# Required on Spaces, where the default ~/.cache is read-only.
os.environ["HF_HOME"] = "/tmp/hf-cache"

import torch
import streamlit as st
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor, BitsAndBytesConfig

# Read the gated-model token defensively: os.environ["HF_TOKEN"] would raise
# KeyError at import time and crash the app with no user-facing message.
HF_TOKEN = os.environ.get("HF_TOKEN")

MODEL_ID = "meta-llama/Llama-3.2-11B-Vision"
CACHE_DIR = "/tmp/hf-cache"


@st.cache_resource(show_spinner="Loading Llama-3.2 Vision model...")  # loaded once per process
def load_model():
    """Load the Llama-3.2 11B Vision model (4-bit NF4 quantized) and its processor.

    Returns:
        tuple: ``(model, processor)`` ready for inference on the device(s)
        chosen by ``device_map="auto"``.
    """
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
    model = MllamaForConditionalGeneration.from_pretrained(
        MODEL_ID,
        quantization_config=quant_config,
        torch_dtype=torch.float16,
        device_map="auto",
        token=HF_TOKEN,
        cache_dir=CACHE_DIR,
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN, cache_dir=CACHE_DIR)
    return model, processor


# --- Streamlit UI ---
# Title corrected: the model is Llama-3.2 Vision, not LLaVA.
st.title("🧠 Llama-3.2 Vision QA (4-bit)")

if HF_TOKEN is None:
    # Gated meta-llama checkpoints cannot be downloaded without a token;
    # fail with a visible message instead of a KeyError traceback.
    st.error("HF_TOKEN environment variable is not set; cannot access the gated model.")
    st.stop()

model, processor = load_model()

uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
prompt = st.text_input("Enter your prompt", value="What do I want?")

if uploaded_file and prompt:
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Mllama base-model prompt format: image placeholder, then the text prompt.
    full_prompt = f"<|image|><|begin_of_text|>{prompt}"
    inputs = processor(image, full_prompt, return_tensors="pt").to(model.device)

    st.write("🌀 Generating response...")
    output = model.generate(**inputs, max_new_tokens=50)

    # Decode only the newly generated tokens; decoding output[0] in full
    # (as the original did) echoes the prompt back into the response.
    prompt_len = inputs["input_ids"].shape[-1]
    response = processor.decode(output[0][prompt_len:], skip_special_tokens=True)

    st.success("✅ Response:")
    st.markdown(response)