Spaces:
Build error
Build error
import os
import re
from io import BytesIO
from urllib.parse import urljoin, urlparse

import gradio as gr
import requests
import torch
from PIL import Image, ImageSequence
from transformers import AutoProcessor, LlavaForConditionalGeneration
# ---- 1️⃣ Model setup (the author's note marks this repo as the public one) ----
MODEL_NAME = "llava-hf/joycaption-llama3.1-8b"  # public version

# Processor and weights are both loaded once, at import time.
processor = AutoProcessor.from_pretrained(MODEL_NAME)

# Weights are placed on the CPU in bfloat16. ``eval()`` returns the module
# itself, so chaining it still binds the loaded model to ``llava_model`` while
# switching it to inference mode.
llava_model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
).eval()
# -------------------------------------------------
# Helper: download a file from a URL
# -------------------------------------------------
def download_bytes(url: str) -> bytes:
    """Fetch *url* with an HTTP GET and return the complete response body.

    Args:
        url: Absolute HTTP(S) URL of the resource to download.

    Returns:
        The raw bytes of the response body.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection failures or when the
            30-second connect/read timeout is exceeded.
    """
    # No ``stream=True``: the body is read in full anyway, and a streamed
    # response would stay open if ``raise_for_status()`` raised before the
    # body was consumed. The context manager closes the response on every
    # path, success or error.
    with requests.get(url, timeout=30) as resp:
        resp.raise_for_status()
        return resp.content
# -------------------------------------------------
# Helper: convert MP4 → GIF using ezgif.com (public API)
# -------------------------------------------------
def mp4_to_gif(mp4_bytes: bytes) -> bytes:
    """Convert an MP4 video to a GIF through ezgif.com and return the GIF bytes.

    The video is uploaded with a multipart POST, the returned HTML is searched
    for the first ``<img>`` tag whose ``src`` ends in ``.gif``, and that file
    is then downloaded.

    NOTE(review): the endpoint is undocumented (per the original author's
    note), so the form field names and the shape of the response HTML are
    assumptions that should be confirmed against the live site.

    Args:
        mp4_bytes: Raw contents of the MP4 file.

    Returns:
        The raw bytes of the generated GIF.

    Raises:
        requests.HTTPError: If the upload or the GIF download returns a
            4xx/5xx status.
        requests.RequestException: On connection failures or timeouts.
        RuntimeError: If no ``.gif`` image link is found in the response.
    """
    endpoint = "https://s.ezgif.com/video-to-gif"

    upload = requests.post(
        endpoint,
        files={"new-file": ("video.mp4", mp4_bytes, "video/mp4")},
        data={"file": "video.mp4"},
        timeout=60,
    )
    upload.raise_for_status()

    # The response HTML is expected to embed the result as an <img> tag;
    # take the first one whose src ends with ".gif".
    match = re.search(r'<img[^>]+src="([^"]+\.gif)"', upload.text)
    if match is None:
        raise RuntimeError("Failed to extract GIF URL from ezgif response")

    # urljoin resolves protocol-relative ("//host/x.gif"), root-relative
    # ("/x.gif") and path-relative ("tmp/x.gif") links against the endpoint,
    # and leaves already-absolute URLs untouched.
    gif_url = urljoin(endpoint, match.group(1))

    with requests.get(gif_url, timeout=30) as gif_resp:
        gif_resp.raise_for_status()
        return gif_resp.content
# -------------------------------------------------
# Main inference function
# -------------------------------------------------
def generate_caption_from_url(url: str, prompt: str = "Describe the image.") -> str:
    """Download an image, GIF or MP4 from *url* and caption its first frame.

    Steps:
        1. Download the resource.
        2. If the URL path ends in ``.mp4``, convert the video to a GIF.
        3. Decode the image and take its first frame as RGB.
        4. Run the model and return only the newly generated text.

    Args:
        url: Direct URL of an image, an animated GIF or an MP4 video.
        prompt: Instruction given to the model. A blank prompt falls back to
            ``"Describe the image."``.

    Returns:
        The generated caption with surrounding whitespace removed.

    Raises:
        ValueError: If *url* is empty or only whitespace.
        requests.RequestException: If the download or the MP4 conversion
            fails at the HTTP level.
        RuntimeError: If the MP4 conversion response contains no GIF link.
        PIL.UnidentifiedImageError: If the downloaded bytes are not an image
            Pillow can decode.
    """
    url = (url or "").strip()
    if not url:
        raise ValueError("No URL provided")
    # The UI labels the prompt as optional, so a cleared box uses the default.
    prompt = (prompt or "").strip() or "Describe the image."

    # 1️⃣ Download raw bytes
    raw = download_bytes(url)

    # 2️⃣ Convert MP4 → GIF. Only the URL *path* is inspected, so links such
    # as "clip.mp4?token=abc" or "clip.mp4#t=3" are still recognised.
    if urlparse(url).path.lower().endswith(".mp4"):
        raw = mp4_to_gif(raw)

    # 3️⃣ Load the image. Pillow opens multi-frame files positioned on frame 0,
    # so converting right away yields the first frame. ``convert`` returns a
    # new in-memory image, which stays valid after the ``with`` block closes
    # the source image.
    with Image.open(BytesIO(raw)) as img:
        frame = img.convert("RGB")

    # 4️⃣ Run the model
    # NOTE(review): LLaVA-style processors commonly expect an image
    # placeholder or chat template in the text; confirm that a bare prompt is
    # valid for this checkpoint.
    inputs = processor(images=frame, text=prompt, return_tensors="pt")
    # Integer tensors (token ids, masks) are only moved to the model's device.
    # Floating-point tensors (pixel values) are also cast to the model's
    # dtype, because the model is loaded in bfloat16 while processors
    # typically emit float32.
    inputs = {
        name: (
            tensor.to(llava_model.device, dtype=llava_model.dtype)
            if torch.is_floating_point(tensor)
            else tensor.to(llava_model.device)
        )
        for name, tensor in inputs.items()
    }
    with torch.no_grad():
        out_ids = llava_model.generate(**inputs, max_new_tokens=64)

    # generate() returns the prompt tokens followed by the continuation;
    # drop the prompt so the caption does not echo the instruction text.
    prompt_len = inputs["input_ids"].shape[1]
    caption = processor.decode(out_ids[0][prompt_len:], skip_special_tokens=True)
    return caption.strip()
# -------------------------------------------------
# Gradio UI
# -------------------------------------------------
# The two text inputs are passed positionally to generate_caption_from_url
# (resource URL first, then the prompt); its return value fills the single
# text output.
iface = gr.Interface(
    title="JoyCaption – URL input (supports GIF & MP4)",
    description=" ".join(
        (
            "Enter a direct URL to an image, an animated GIF, or an MP4 video.",
            "MP4 files are automatically converted to GIF via ezgif.com,",
            "and the first frame of the GIF is captioned.",
        )
    ),
    allow_flagging="never",
    fn=generate_caption_from_url,
    inputs=[
        gr.Textbox(
            placeholder="https://example.com/photo.jpg or https://example.com/clip.mp4",
            label="Image / GIF / MP4 URL",
        ),
        gr.Textbox(value="Describe the image.", label="Prompt (optional)"),
    ],
    outputs=gr.Textbox(label="Generated caption"),
)

# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    iface.launch()