"""Gradio app that captions an image, GIF, or MP4 given by URL using JoyCaption.

MP4 inputs are converted to GIF through the ezgif.com web service; for any
animated input only the first frame is captioned.
"""

import os
import re
from io import BytesIO
from urllib.parse import urljoin, urlparse

import gradio as gr
import requests
import torch
from PIL import Image, ImageSequence
from transformers import AutoProcessor, LlavaForConditionalGeneration

# ---------------------------
# Config
# ---------------------------
MODEL_NAME = "fancyfeast/llama-joycaption-beta-one-hf-llava"
HF_TOKEN = os.getenv("HF_TOKEN")  # optional secret in Space settings

DEFAULT_PROMPT = "Describe the image."
# System message taken from the JoyCaption model card's usage example.
SYSTEM_PROMPT = "You are a helpful image captioner."
MAX_NEW_TOKENS = 128
# Hard cap on any single download: the URL comes from an untrusted user, so an
# unbounded read could exhaust memory.
MAX_DOWNLOAD_BYTES = 50 * 1024 * 1024

EZGIF_UPLOAD_URL = "https://s.ezgif.com/video-to-gif"
# Preferred: the <img> tag that shows the converted GIF.
_GIF_IMG_RE = re.compile(r'<img[^>]+src="([^"]+\.gif)"')
# Fallback: any src attribute pointing at a GIF under a /tmp/ path.
_GIF_TMP_RE = re.compile(r'src="([^"]+?/tmp/[^"]+\.gif)"')

# ISO-BMFF major brands used by still-image formats (HEIF/HEIC/AVIF). These
# files start with an "ftyp" box just like MP4 but are images, not video, and
# must not be sent to the video converter.
_IMAGE_FTYP_BRANDS = frozenset(
    {b"avif", b"avis", b"heic", b"heix", b"heim", b"heis", b"hevc", b"hevx", b"mif1", b"msf1"}
)

# ---------------------------
# Load model & processor
# ---------------------------
token_arg = {"token": HF_TOKEN} if HF_TOKEN else {}
processor = AutoProcessor.from_pretrained(MODEL_NAME, **token_arg)
llava_model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    device_map="cpu",
    torch_dtype=torch.bfloat16,
    **token_arg,
)
llava_model.eval()


# ---------------------------
# Helpers
# ---------------------------
def download_bytes(url: str, timeout: int = 30, max_bytes: int = MAX_DOWNLOAD_BYTES) -> bytes:
    """Fetch *url* and return the response body.

    The body is streamed in chunks so that oversized responses are rejected
    without being fully buffered.

    Args:
        url: HTTP(S) URL to fetch.
        timeout: Connect/read timeout in seconds passed to ``requests``.
        max_bytes: Maximum accepted body size.

    Raises:
        requests.RequestException: On network failure or a non-2xx status.
        ValueError: If the body exceeds *max_bytes*.
    """
    with requests.get(url, stream=True, timeout=timeout) as resp:
        resp.raise_for_status()
        chunks = []
        received = 0
        for chunk in resp.iter_content(chunk_size=64 * 1024):
            received += len(chunk)
            if received > max_bytes:
                raise ValueError(f"response larger than {max_bytes} bytes")
            chunks.append(chunk)
    return b"".join(chunks)


def mp4_to_gif(mp4_bytes: bytes) -> bytes:
    """Convert MP4 bytes to GIF bytes via the ezgif.com web converter.

    The converted GIF's URL is scraped from the returned HTML page, and the
    GIF is then downloaded.

    Raises:
        requests.RequestException: On network failure or a non-2xx status.
        RuntimeError: If no GIF URL can be found in ezgif's response.
    """
    files = {"new-file": ("video.mp4", mp4_bytes, "video/mp4")}
    resp = requests.post(
        EZGIF_UPLOAD_URL,
        files=files,
        data={"file": "video.mp4"},
        timeout=120,
    )
    resp.raise_for_status()

    match = _GIF_IMG_RE.search(resp.text) or _GIF_TMP_RE.search(resp.text)
    if not match:
        raise RuntimeError("Failed to extract GIF URL from ezgif response")

    # urljoin resolves protocol-relative ("//host/x"), root-relative ("/x") and
    # page-relative src values against the ezgif host; absolute URLs pass through.
    gif_url = urljoin(EZGIF_UPLOAD_URL, match.group(1))
    return download_bytes(gif_url, timeout=60)


def load_first_frame_from_bytes(raw: bytes) -> Image.Image:
    """Decode *raw* with Pillow and return its first frame as an RGB image."""
    img = Image.open(BytesIO(raw))
    if getattr(img, "is_animated", False):
        img = next(ImageSequence.Iterator(img))
    if img.mode != "RGB":
        img = img.convert("RGB")
    return img


def _looks_like_mp4(url: str, raw: bytes) -> bool:
    """Heuristically decide whether the downloaded payload is an MP4 video.

    A URL path ending in ``.mp4`` wins. Otherwise the ISO-BMFF header is
    checked: bytes 4-8 are ``ftyp`` and bytes 8-12 the major brand. Still-image
    brands (HEIF/AVIF) are excluded.
    """
    if urlparse(url).path.lower().endswith(".mp4"):
        return True
    if raw[4:8] == b"ftyp":
        return raw[8:12].lower() not in _IMAGE_FTYP_BRANDS
    return False


def _build_chat_prompt(prompt: str) -> str:
    """Render *prompt* through the processor's chat template.

    This follows the JoyCaption model card: system + user turns, rendered as
    text with a generation prompt appended.
    """
    convo = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    text = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
    # NOTE(review): the JoyCaption template is expected to insert the image
    # placeholder itself. If it did not, add it explicitly so the processor can
    # match image features to tokens.
    image_token = getattr(processor, "image_token", None)
    if image_token and image_token not in text:
        convo[1]["content"] = f"{image_token}\n{prompt}"
        text = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
    return text


def _caption_image(img: Image.Image, prompt: str) -> str:
    """Run JoyCaption on *img* with *prompt* and return only the generated text."""
    text = _build_chat_prompt(prompt)
    inputs = processor(text=[text], images=[img], return_tensors="pt").to(llava_model.device)
    # Cast pixels to the model's dtype (bfloat16 here), as the JoyCaption usage
    # example does; a float32 input against bfloat16 weights fails.
    inputs["pixel_values"] = inputs["pixel_values"].to(llava_model.dtype)
    with torch.no_grad():
        out_ids = llava_model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    # generate() returns prompt + completion; keep only the new tokens so the
    # caption does not echo the prompt.
    new_ids = out_ids[0][inputs["input_ids"].shape[1]:]
    return processor.tokenizer.decode(
        new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    ).strip()


# ---------------------------
# Main inference
# ---------------------------
def generate_caption_from_url(url: str, prompt: str = DEFAULT_PROMPT) -> str:
    """Download the media at *url* and caption its first frame.

    Errors are returned as human-readable strings, because this function backs
    a Gradio textbox.

    Args:
        url: Direct link to an image, GIF, or MP4.
        prompt: Instruction for the captioner. A blank value falls back to
            the default prompt.

    Returns:
        The generated caption, or an error message.
    """
    url = (url or "").strip()
    if not url:
        return "No URL provided."
    prompt = (prompt or "").strip() or DEFAULT_PROMPT

    try:
        raw = download_bytes(url)
    except Exception as e:
        return f"Download error: {e}"

    try:
        if _looks_like_mp4(url, raw):
            try:
                raw = mp4_to_gif(raw)
            except Exception as e:
                return f"MP4→GIF conversion failed: {e}"
        img = load_first_frame_from_bytes(raw)
    except Exception as e:
        return f"Image processing error: {e}"

    try:
        return _caption_image(img, prompt)
    except Exception as e:
        return f"Inference error: {e}"


# ---------------------------
# Gradio UI (compatible init)
# ---------------------------
gradio_kwargs = dict(
    fn=generate_caption_from_url,
    inputs=[
        gr.Textbox(label="Image / GIF / MP4 URL", placeholder="https://example.com/photo.jpg"),
        gr.Textbox(label="Prompt (optional)", value="Describe the image."),
    ],
    outputs=gr.Textbox(label="Generated caption"),
    title="JoyCaption (fancyfeast) - URL input",
    description="Paste a direct link to an image, GIF, or MP4. MP4 files are converted to GIF via ezgif.com; the first frame is captioned.",
)

# Newer Gradio spells the option `flagging_mode`; older releases only accept
# `allow_flagging`. Try each, then fall back to Gradio's default.
iface = None
for _flag_kwargs in ({"flagging_mode": "never"}, {"allow_flagging": "never"}):
    try:
        iface = gr.Interface(**gradio_kwargs, **_flag_kwargs)
        break
    except TypeError:
        continue
if iface is None:
    iface = gr.Interface(**gradio_kwargs)

if __name__ == "__main__":
    iface.launch()