"""Gradio app: caption an image, GIF, or MP4 fetched from a URL with JoyCaption.

The module loads the processor and the LLaVA model at import time (CPU,
float32) and exposes a two-textbox Gradio interface whose handler is
``generate_caption_from_url``.
"""
import io
import os
import re

import gradio as gr
import requests
import torch
from PIL import Image, ImageSequence
from transformers import AutoProcessor, LlavaForConditionalGeneration

MODEL_NAME = "fancyfeast/llama-joycaption-beta-one-hf-llava"
HF_TOKEN = os.getenv("HF_TOKEN")  # optional
DEFAULT_PROMPT = "Describe the image."

_EZGIF_BASE_URL = "https://s.ezgif.com"
_EZGIF_CONVERT_URL = _EZGIF_BASE_URL + "/video-to-gif"

# Patterns tried in order against the ezgif response HTML.
# The first one previously lacked its ``<img[^>`` prefix, which made it require
# a literal ``]`` right before ``src=`` so it could effectively never match.
_GIF_SRC_PATTERNS = (
    re.compile(r'<img[^>]+src="([^"]+\.gif)"'),
    re.compile(r'src="([^"]+?/tmp/[^"]+\.gif)"'),
)


def download_bytes(url: str, timeout: int = 30) -> bytes:
    """Fetch *url* with an HTTP GET and return the full response body.

    Args:
        url: Address to download.
        timeout: Timeout in seconds handed to ``requests.get``.

    Raises:
        requests.RequestException: On connection problems or a non-2xx status
            (via ``raise_for_status``).
    """
    with requests.get(url, stream=True, timeout=timeout) as resp:
        resp.raise_for_status()
        return resp.content


def mp4_to_gif(mp4_bytes: bytes) -> bytes:
    """Convert MP4 bytes to GIF bytes through the ezgif web service.

    The video is uploaded to ezgif's ``video-to-gif`` endpoint, the response
    HTML is scanned for a ``.gif`` URL, and that GIF is downloaded.

    NOTE: this sends the video content to a third-party service.

    Args:
        mp4_bytes: Raw MP4 file content.

    Returns:
        The raw bytes of the GIF referenced by the ezgif response.

    Raises:
        requests.RequestException: If the upload or the GIF download fails.
        RuntimeError: If no GIF URL can be found in the response HTML.
    """
    files = {"new-file": ("video.mp4", mp4_bytes, "video/mp4")}
    resp = requests.post(
        _EZGIF_CONVERT_URL,
        files=files,
        data={"file": "video.mp4"},
        timeout=120,
    )
    resp.raise_for_status()

    match = None
    for pattern in _GIF_SRC_PATTERNS:
        match = pattern.search(resp.text)
        if match:
            break
    if not match:
        raise RuntimeError("Failed to extract GIF URL from ezgif response")

    gif_url = match.group(1)
    # Turn protocol-relative and site-relative links into absolute URLs.
    if gif_url.startswith("//"):
        gif_url = "https:" + gif_url
    elif gif_url.startswith("/"):
        gif_url = _EZGIF_BASE_URL + gif_url

    with requests.get(gif_url, timeout=60) as gif_resp:
        gif_resp.raise_for_status()
        return gif_resp.content


def load_first_frame_from_bytes(raw: bytes) -> Image.Image:
    """Decode *raw* image bytes and return the first frame as an RGB image.

    Args:
        raw: Encoded image data in any format Pillow can open.

    Returns:
        The image (first frame when the image reports ``is_animated``),
        converted to mode ``"RGB"`` if it is not already.
    """
    img = Image.open(io.BytesIO(raw))
    if getattr(img, "is_animated", False):
        img = next(ImageSequence.Iterator(img))
    if img.mode != "RGB":
        img = img.convert("RGB")
    return img


# Load processor + model
# NOTE: ``token_arg`` is not referenced anywhere in the visible file; it is
# kept only so the module's globals stay unchanged.
token_arg = {"use_auth_token": HF_TOKEN} if HF_TOKEN else {}
# Auth kwargs shared by both ``from_pretrained`` calls (empty without a token).
_auth_kwargs = {"token": HF_TOKEN} if HF_TOKEN else {}

processor = AutoProcessor.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    num_additional_image_tokens=1,
    **_auth_kwargs,
)

# CPU Space -> use float32
llava_model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    device_map="cpu",
    torch_dtype=torch.float32,
    trust_remote_code=True,
    **_auth_kwargs,
)
llava_model.eval()


def _looks_like_mp4(url: str, raw: bytes) -> bool:
    """Return True if the URL ends in ``.mp4`` or the data starts like an MP4."""
    url_without_query = url.lower().split("?")[0]
    return url_without_query.endswith(".mp4") or b"ftyp" in raw[:16].lower()


def generate_caption_from_url(url: str, prompt: str = DEFAULT_PROMPT) -> str:
    """Download the media at *url* and return a model-generated caption.

    This is the Gradio handler, so it never raises: every failure is turned
    into a human-readable message that is returned in place of the caption.

    Args:
        url: Direct link to an image, GIF, or MP4. MP4 input is converted to
            a GIF first; only the first frame is captioned.
        prompt: Instruction given to the model. A blank prompt falls back to
            ``DEFAULT_PROMPT``.

    Returns:
        The generated caption, or an error message describing which stage
        (download, conversion, image processing, inference) failed.
    """
    url = (url or "").strip()
    if not url:
        return "No URL provided."
    # The UI labels the prompt as optional, so tolerate it being cleared.
    prompt = (prompt or "").strip() or DEFAULT_PROMPT

    try:
        raw = download_bytes(url)
    except Exception as e:
        return f"Download error: {e}"

    try:
        if _looks_like_mp4(url, raw):
            try:
                raw = mp4_to_gif(raw)
            except Exception as e:
                return f"MP4→GIF conversion failed: {e}"
        img = load_first_frame_from_bytes(raw)
    except Exception as e:
        return f"Image processing error: {e}"

    # Resize to conservative default
    try:
        img = img.resize((512, 512), resample=Image.BICUBIC)
    except Exception as e:
        # Best effort: continue with the original size, but leave a trace.
        print("resize skipped:", e)

    try:
        conversation = [
            {
                "role": "user",
                "content": [{"type": "image"}, {"type": "text", "text": prompt}],
            }
        ]
        # apply_chat_template returns plain text unless tokenize=True, so
        # render the prompt first and let the processor call tokenize it
        # together with the image.
        prompt_text = processor.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = processor(text=[prompt_text], images=[img], return_tensors="pt")

        device = llava_model.device
        inputs = {
            key: value.to(device) if hasattr(value, "to") else value
            for key, value in inputs.items()
        }
        # Minimal debug info (appears in Space logs)
        if "pixel_values" in inputs:
            inputs["pixel_values"] = inputs["pixel_values"].to(
                dtype=llava_model.dtype, device=device
            )
            print("pixel_values.shape:", inputs["pixel_values"].shape)
        if "input_ids" in inputs:
            print("input_ids.shape:", inputs["input_ids"].shape)

        with torch.no_grad():
            out_ids = llava_model.generate(**inputs, max_new_tokens=128)

        # generate() returns prompt + continuation; keep only the new tokens
        # so the caption does not echo the prompt.
        prompt_len = inputs["input_ids"].shape[1]
        caption = processor.decode(out_ids[0][prompt_len:], skip_special_tokens=True)
        return caption.strip()
    except Exception as e:
        return f"Inference error: {e}"


gradio_kwargs = dict(
    fn=generate_caption_from_url,
    inputs=[
        gr.Textbox(
            label="Image / GIF / MP4 URL",
            placeholder="https://example.com/photo.jpg",
        ),
        gr.Textbox(label="Prompt (optional)", value=DEFAULT_PROMPT),
    ],
    outputs=gr.Textbox(label="Generated caption"),
    title="JoyCaption - URL input",
    description="Paste a direct link to an image/GIF/MP4 (MP4 will be converted).",
)

# Fall back to the plain constructor when this Gradio version rejects the
# ``allow_flagging`` keyword.
try:
    iface = gr.Interface(**gradio_kwargs, allow_flagging="never")
except TypeError:
    iface = gr.Interface(**gradio_kwargs)


if __name__ == "__main__":
    try:
        iface.launch(server_name="0.0.0.0", server_port=7860)
    finally:
        # Best-effort shutdown: close the event loop if one is still open;
        # any failure here is deliberately ignored.
        try:
            import asyncio

            loop = asyncio.get_event_loop()
            if not loop.is_closed():
                loop.close()
        except Exception:
            pass