# app.py import os from typing import Tuple, List import gradio as gr import spaces # <- habilita ZeroGPU decorators import torch from PIL import Image from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration MODEL_ID = os.environ.get("MODEL_ID", "BSC-LT/salamandra-7b-vision") DTYPE = torch.float16 # half precision para H200/A100 DEVICE = "cuda" # ZeroGPU asigna gpu por llamada en @spaces.GPU # Carga perezosa: sólo la primera vez que se invoca en GPU _model = None _processor = None def _lazy_load(): global _model, _processor if _model is None or _processor is None: _processor = AutoProcessor.from_pretrained(MODEL_ID) _model = LlavaOnevisionForConditionalGeneration.from_pretrained( MODEL_ID, torch_dtype=DTYPE, low_cpu_mem_usage=True, trust_remote_code=True, device_map=None, # movemos explícitamente a cuda con @spaces.GPU use_safetensors=True, ) return _model, _processor @spaces.GPU # <- asegura que la función se ejecute con GPU asignada def describe(image: Image.Image, prompt_text: str, max_new_tokens: int, temperature: float) -> str: """ Devuelve una descripción a partir de imagen + prompt en texto. """ model, processor = _lazy_load() # Formateo estilo chat template recomendado por el model card conversation = [ { "role": "user", "content": [ {"type": "image"}, {"type": "text", "text": prompt_text or "Descriu la imatge amb el màxim detall possible."}, ], } ] prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) # A GPU justo antes de inferir (ZeroGPU) model = model.to(DEVICE) inputs = processor(images=image, text=prompt, return_tensors="pt").to(DEVICE, DTYPE) with torch.inference_mode(): output = model.generate( **inputs, max_new_tokens=int(max_new_tokens), temperature=float(temperature), ) text = processor.decode(output[0], skip_special_tokens=True) return text.strip() with gr.Blocks(title="Salamandra Vision 7B (ZeroGPU)") as demo: gr.Markdown("# Salamandra-Vision 7B · ZeroGPU\nEnvía una imagen y un texto/prompta, recibe una descripción.") with gr.Row(): with gr.Column(): in_img = gr.Image(label="Imagen", type="pil") in_txt = gr.Textbox( label="Texto/prompta", value="Describe la imagen con el mayor detalle posible (en catalán o español)." ) max_new = gr.Slider(16, 1024, value=256, step=16, label="max_new_tokens") temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature") btn = gr.Button("Generar", variant="primary") with gr.Column(): out = gr.Textbox(label="Descripción", lines=18) btn.click(describe, inputs=[in_img, in_txt, max_new, temp], outputs=out, api_name="describe") # Cola de Gradio: útil para ZeroGPU y picos de demanda demo.queue(concurrency_count=1, max_size=16).launch()