"""Gradio demo: image captioning with SVECTOR's Fal-2-500M vision-language model."""

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

MID = "SVECTOR-CORPORATION/Fal-2-500M"

# LLaVA-convention sentinel token id: spliced into input_ids at the position
# where the model injects the vision tower's image features.
IMAGE_TOKEN_INDEX = -200

# Textual placeholder carried through the chat template, then replaced by
# IMAGE_TOKEN_INDEX.  BUG FIX: this string was missing from the prompt and
# from the split call — rendered.split("", 1) always raised
# "ValueError: empty separator", so every caption request errored out.
IMAGE_PLACEHOLDER = "<image>"

tok = None
model = None


def load_model():
    """Lazily load the tokenizer and model once (module-level singleton).

    Returns:
        tuple: ``(tokenizer, model)`` ready for inference — fp16 with
        ``device_map="auto"`` on CUDA, fp32 on CPU — with the model in
        eval mode.
    """
    global tok, model
    if tok is None or model is None:
        print("Loading model...")
        tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)

        # Pick device and dtype: half precision only where CUDA is available.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        if torch.cuda.is_available():
            # device_map="auto" lets accelerate place the weights on GPU(s).
            model = AutoModelForCausalLM.from_pretrained(
                MID,
                torch_dtype=dtype,
                device_map="auto",
                trust_remote_code=True,
            )
        else:
            # CPU path: load without device_map, then move explicitly.
            model = AutoModelForCausalLM.from_pretrained(
                MID,
                torch_dtype=dtype,
                trust_remote_code=True,
            )
            model = model.to(device)

        model.eval()  # inference only — disable dropout etc.
        print(f"Model loaded successfully on {device}!")
    return tok, model


@spaces.GPU(duration=60)
def caption_image(image, custom_prompt=None):
    """Generate a text caption for *image*.

    Args:
        image: A ``PIL.Image`` (or ``None``, which returns a hint string).
        custom_prompt: Optional instruction; empty/None falls back to
            ``"Describe this image in detail."``.

    Returns:
        str: The model's response, or a human-readable error message —
        this function never raises, so the UI always gets text back.
    """
    if image is None:
        return "Please upload an image first."

    try:
        tok, model = load_model()

        # The vision processor expects RGB input.
        if image.mode != "RGB":
            image = image.convert("RGB")

        prompt = custom_prompt if custom_prompt else "Describe this image in detail."

        # Put the image placeholder ahead of the user text in one user turn.
        messages = [
            {"role": "user", "content": f"{IMAGE_PLACEHOLDER}\n{prompt}"}
        ]

        # Render to a plain string (tokenize=False) so the sentinel image id
        # can be spliced in manually at the placeholder position.
        rendered = tok.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )

        # BUG FIX: was rendered.split("", 1) — an empty separator raises
        # ValueError unconditionally.  Split on the real placeholder, and
        # fail with a clear message if the template stripped it.
        if IMAGE_PLACEHOLDER not in rendered:
            return "Error: the chat template dropped the image placeholder."
        pre, post = rendered.split(IMAGE_PLACEHOLDER, 1)

        # Tokenize the text on either side of the image slot; no special
        # tokens — the chat template already added them.
        pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
        post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids

        # Match the model's actual placement/precision (device_map may shard).
        device = next(model.parameters()).device
        dtype = next(model.parameters()).dtype

        # Splice the sentinel image token id between the two text segments.
        img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
        input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(device)
        attention_mask = torch.ones_like(input_ids, device=device)

        # Preprocess with the model's own vision-tower image processor so
        # normalization/resizing match training.
        px = model.get_vision_tower().image_processor(
            images=image, return_tensors="pt"
        )["pixel_values"]
        px = px.to(device, dtype=dtype)

        with torch.no_grad():
            out = model.generate(
                inputs=input_ids,
                attention_mask=attention_mask,
                images=px,
                max_new_tokens=128,
                do_sample=False,  # greedy decoding; temperature is irrelevant
            )

        generated_text = tok.decode(out[0], skip_special_tokens=True)

        # Heuristic: keep only the assistant turn when the decoded text still
        # contains the role marker; otherwise return everything.
        if "assistant" in generated_text:
            response = generated_text.split("assistant")[-1].strip()
        else:
            response = generated_text

        return response

    except Exception as e:
        # Surface the full traceback in the UI rather than crashing the app.
        import traceback
        error_detail = traceback.format_exc()
        return f"Error generating caption: {str(e)}\n\nDetails:\n{error_detail}"


# --------------------------- Gradio interface ---------------------------
with gr.Blocks(title="Fal-2 Image Captioning") as demo:
    gr.Markdown(
        """
        # 🖼️ Fal-2-500M Image Captioning

        Upload an image to generate a detailed caption using SVECTOR's Fal-2-500M model.
        You can use the default prompt or provide your own custom prompt.
        """
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                elem_id="image-upload",
            )
            custom_prompt = gr.Textbox(
                label="Custom Prompt (Optional)",
                placeholder="Leave empty for default: 'Describe this image in detail.'",
                lines=2,
            )
            with gr.Row():
                clear_btn = gr.ClearButton([image_input, custom_prompt])
                generate_btn = gr.Button("Generate Caption", variant="primary")
        with gr.Column():
            output = gr.Textbox(
                label="Generated Caption",
                lines=8,
                max_lines=15,
                show_copy_button=True,
            )

    # Explicit button click always runs captioning.
    generate_btn.click(
        fn=caption_image,
        inputs=[image_input, custom_prompt],
        outputs=output,
    )

    # Auto-caption on upload, but only when no custom prompt is set.
    image_input.change(
        fn=lambda img, prompt: caption_image(img, prompt)
        if img is not None and not prompt
        else None,
        inputs=[image_input, custom_prompt],
        outputs=output,
    )

    gr.Markdown(
        """
        ---
        **Model:** Fal-2-500M (by SVECTOR)
        """
    )

if __name__ == "__main__":
    demo.launch(
        share=False,
        show_error=True,
        server_name="0.0.0.0",
        server_port=7860,
    )