Spaces:
Build error
Build error
| import os | |
| import re | |
| import torch | |
| import requests | |
| from io import BytesIO | |
| from PIL import Image, ImageSequence | |
| from transformers import AutoProcessor, LlavaForConditionalGeneration | |
| import gradio as gr | |
# ---------------------------
# Config
# ---------------------------
# Hub repository id used for both the processor and the model weights.
MODEL_NAME = "fancyfeast/llama-joycaption-beta-one-hf-llava"
# Optional access token read from the environment (e.g. a secret configured
# in the Space settings); ``None`` when the variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")

# ---------------------------
# Load model & processor
# ---------------------------
# Only forward ``token`` when one is configured; otherwise pass nothing extra.
token_arg = {}
if HF_TOKEN:
    token_arg["token"] = HF_TOKEN

processor = AutoProcessor.from_pretrained(MODEL_NAME, **token_arg)

# Weights are loaded onto the CPU in bfloat16.
llava_model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    device_map="cpu",
    torch_dtype=torch.bfloat16,
    **token_arg,
)
# Put the module into evaluation mode before serving requests.
llava_model.eval()
| # --------------------------- | |
| # Helpers | |
| # --------------------------- | |
def download_bytes(url: str, timeout: int = 30) -> bytes:
    """Fetch *url* with an HTTP GET and return the complete response body.

    Args:
        url: Address of the resource to download.
        timeout: Value passed to ``requests.get`` as its ``timeout``.

    Returns:
        The raw bytes of the response body.

    Raises:
        requests.HTTPError: If the server answers with an error status
            (raised by ``raise_for_status``).
        requests.RequestException: For other request failures such as
            connection errors or timeouts.
    """
    # With ``stream=True`` the connection is only released once the body has
    # been consumed or the response is closed.  The ``with`` block guarantees
    # the close even when ``raise_for_status`` aborts before the body is read.
    with requests.get(url, stream=True, timeout=timeout) as resp:
        resp.raise_for_status()
        return resp.content
def mp4_to_gif(mp4_bytes: bytes) -> bytes:
    """Convert MP4 data to a GIF through the ezgif.com web converter.

    The video is uploaded as a multipart form, the returned HTML is searched
    for a reference to a ``.gif`` file, and that file is then downloaded.

    Args:
        mp4_bytes: Raw contents of the MP4 file.

    Returns:
        Raw bytes of the downloaded GIF.

    Raises:
        requests.HTTPError: If the upload or the GIF download returns an
            error status.
        RuntimeError: If no GIF URL can be found in the response page.
    """
    upload = requests.post(
        "https://s.ezgif.com/video-to-gif",
        files={"new-file": ("video.mp4", mp4_bytes, "video/mp4")},
        data={"file": "video.mp4"},
        timeout=120,
    )
    upload.raise_for_status()
    page = upload.text

    # Prefer an <img> tag whose src ends in .gif; otherwise accept any src
    # attribute pointing at a .gif below a /tmp/ path.
    found = re.search(r'<img[^>]+src="([^"]+\.gif)"', page) or re.search(
        r'src="([^"]+?/tmp/[^"]+\.gif)"', page
    )
    if found is None:
        raise RuntimeError("Failed to extract GIF URL from ezgif response")

    gif_url = found.group(1)
    # Expand protocol-relative ("//host/...") and root-relative ("/...")
    # references into absolute URLs; anything else is used unchanged.
    if gif_url.startswith("//"):
        gif_url = "https:" + gif_url
    elif gif_url.startswith("/"):
        gif_url = "https://s.ezgif.com" + gif_url

    download = requests.get(gif_url, timeout=60)
    download.raise_for_status()
    return download.content
def load_first_frame_from_bytes(raw: bytes) -> Image.Image:
    """Decode *raw* image data and return its first frame as an RGB image.

    Args:
        raw: Encoded image bytes in any format ``PIL.Image.open`` accepts.

    Returns:
        The image (for animated images, the first frame yielded by
        ``ImageSequence.Iterator``) in ``"RGB"`` mode.
    """
    frame = Image.open(BytesIO(raw))

    # Not every image class exposes ``is_animated``; treat a missing
    # attribute the same as a still image.
    animated = getattr(frame, "is_animated", False)
    if animated:
        frame = next(ImageSequence.Iterator(frame))

    return frame if frame.mode == "RGB" else frame.convert("RGB")
| # --------------------------- | |
| # Main inference | |
| # --------------------------- | |
def generate_caption_from_url(url: str, prompt: str = "Describe the image.") -> str:
    """Download the media at *url* and return a caption for its first frame.

    Errors are never raised to the caller; each stage reports failures as a
    human-readable string so the result can be shown directly in the UI.

    Args:
        url: Direct link to an image, GIF or MP4.  Empty, whitespace-only or
            ``None`` values yield ``"No URL provided."``.
        prompt: Instruction given to the model.  An empty or whitespace-only
            prompt falls back to ``"Describe the image."``.

    Returns:
        The generated caption, or a message starting with
        ``"Download error:"``, ``"MP4→GIF conversion failed:"``,
        ``"Image processing error:"`` or ``"Inference error:"``.
    """
    url = (url or "").strip()
    if not url:
        return "No URL provided."
    # The prompt field is optional in the UI, so an empty value must not be
    # forwarded to the model as an empty instruction.
    prompt = (prompt or "").strip() or "Describe the image."

    try:
        raw = download_bytes(url)
    except Exception as e:
        return f"Download error: {e}"

    # Crude MP4 detection: file extension (query string ignored) or an
    # "ftyp" box signature within the first 16 bytes.
    path_part = url.lower().split("?")[0]
    if path_part.endswith(".mp4") or b"ftyp" in raw[:16].lower():
        try:
            raw = mp4_to_gif(raw)
        except Exception as e:
            return f"MP4→GIF conversion failed: {e}"

    try:
        img = load_first_frame_from_bytes(raw)
    except Exception as e:
        return f"Image processing error: {e}"

    try:
        # LLaVA-style processors need the prompt rendered through the chat
        # template so the text carries the image placeholder the model
        # expects; passing the bare prompt leaves no slot for image features.
        # NOTE(review): system prompt follows the model card's example usage.
        convo = [
            {"role": "system", "content": "You are a helpful image captioner."},
            {"role": "user", "content": prompt},
        ]
        convo_text = processor.apply_chat_template(
            convo, tokenize=False, add_generation_prompt=True
        )
        inputs = processor(images=[img], text=[convo_text], return_tensors="pt")

        # Move everything to the model's device.  Floating-point tensors
        # (the pixel values) are additionally cast to the model dtype, since
        # the processor emits float32 while the weights are bfloat16.
        device = llava_model.device
        dtype = llava_model.dtype
        inputs = {
            k: v.to(device, dtype=dtype) if v.is_floating_point() else v.to(device)
            for k, v in inputs.items()
        }

        with torch.no_grad():
            out_ids = llava_model.generate(**inputs, max_new_tokens=128)

        # ``generate`` returns prompt + continuation; keep only the newly
        # generated tokens so the caption does not echo the prompt.
        prompt_len = inputs["input_ids"].shape[1]
        new_tokens = out_ids[0][prompt_len:]
        return processor.decode(new_tokens, skip_special_tokens=True).strip()
    except Exception as e:
        return f"Inference error: {e}"
# ---------------------------
# Gradio UI (compatible init)
# ---------------------------
# Keyword arguments shared by both construction attempts below.
gradio_kwargs = {
    "fn": generate_caption_from_url,
    "inputs": [
        gr.Textbox(label="Image / GIF / MP4 URL", placeholder="https://example.com/photo.jpg"),
        gr.Textbox(label="Prompt (optional)", value="Describe the image."),
    ],
    "outputs": gr.Textbox(label="Generated caption"),
    "title": "JoyCaption (fancyfeast) - URL input",
    "description": "Paste a direct link to an image, GIF, or MP4. MP4 files are converted to GIF via ezgif.com; the first frame is captioned.",
}

# Ask for flagging to be disabled; if this Gradio version rejects the
# ``allow_flagging`` keyword with a TypeError, build the interface without it.
try:
    iface = gr.Interface(allow_flagging="never", **gradio_kwargs)
except TypeError:
    iface = gr.Interface(**gradio_kwargs)

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    iface.launch()