Hug0endob's picture
Update app.py
f275d7c verified
raw
history blame
4.95 kB
import os, torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
import gradio as gr
from PIL import Image, ImageSequence
import requests
from io import BytesIO
# ---- 1️⃣ Checkpoint, processor and model (loaded once, at import time) ----
MODEL_NAME = "llava-hf/joycaption-llama3.1-8b"  # public version
processor = AutoProcessor.from_pretrained(MODEL_NAME)
# Weights are loaded onto the CPU in bfloat16.  ``eval()`` returns the module
# it was called on, so chaining it still binds ``llava_model`` to the loaded
# model, now switched to evaluation mode.
llava_model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
).eval()
# -------------------------------------------------
# Helper: download a file from a URL
# -------------------------------------------------
def download_bytes(url: str) -> bytes:
    """Download *url* and return the complete response body.

    Args:
        url: Address of the resource to fetch.

    Returns:
        The raw bytes of the response payload.

    Raises:
        requests.HTTPError: If the server replies with an error status
            (raised by ``raise_for_status()``).
        requests.RequestException: For other request failures such as
            connection errors or timeouts.
    """
    # With ``stream=True`` the body is not read until ``.content`` is
    # accessed, so the status can be checked before downloading anything.
    # The ``with`` block closes the response on every exit path; without it
    # an error status would leave the streamed connection open.
    with requests.get(url, stream=True, timeout=30) as resp:
        resp.raise_for_status()
        return resp.content
# -------------------------------------------------
# Helper: convert MP4 → GIF using ezgif.com (undocumented web endpoint)
# -------------------------------------------------
def mp4_to_gif(mp4_bytes: bytes) -> bytes:
    """Convert MP4 data to GIF data by uploading it to ezgif.com.

    The video is sent as a multipart POST, the returned HTML is searched for
    the first ``<img>`` tag whose ``src`` ends in ``.gif``, and that GIF is
    then downloaded.

    NOTE(review): the endpoint is undocumented, so the form field names and
    the shape of the response HTML are assumptions that may stop holding
    without notice.

    Args:
        mp4_bytes: Contents of the MP4 file to convert.

    Returns:
        The bytes of the GIF referenced by the conversion page.

    Raises:
        requests.HTTPError: If either HTTP request returns an error status.
        RuntimeError: If no ``.gif`` image link is found in the response.
    """
    import re
    from urllib.parse import urljoin

    endpoint = "https://s.ezgif.com/video-to-gif"
    upload = {"new-file": ("video.mp4", mp4_bytes, "video/mp4")}
    resp = requests.post(
        endpoint,
        files=upload,
        data={"file": "video.mp4"},
        timeout=60,
    )
    resp.raise_for_status()

    # The response HTML is expected to contain a link to the generated GIF;
    # take the first <img src="..."> that ends with .gif.
    match = re.search(r'<img[^>]+src="([^"]+\.gif)"', resp.text)
    if match is None:
        raise RuntimeError("Failed to extract GIF URL from ezgif response")

    # Resolve the link against the endpoint.  ``urljoin`` covers the
    # protocol-relative ("//host/x.gif") and root-relative ("/x.gif") forms
    # the original code special-cased, leaves absolute URLs untouched, and
    # additionally handles path-relative links ("tmp/x.gif").
    gif_url = urljoin(endpoint, match.group(1))

    gif_resp = requests.get(gif_url, timeout=30)
    gif_resp.raise_for_status()
    return gif_resp.content
# -------------------------------------------------
# Main inference function
# -------------------------------------------------
def generate_caption_from_url(url: str, prompt: str = "Describe the image.") -> str:
    """Download the resource at *url* and caption it with the loaded model.

    Steps:
        1. Download the raw bytes.
        2. If the URL path ends in ``.mp4``, convert the video to a GIF.
        3. Open the data as an image, keeping only the first frame of an
           animated file, and convert it to RGB.
        4. Run the processor and the model, then decode the output.

    Args:
        url: Direct link to an image, GIF or MP4.  Surrounding whitespace is
            ignored.
        prompt: Text passed to the processor alongside the image.  A blank
            prompt falls back to ``"Describe the image."``.

    Returns:
        The decoded model output with special tokens removed.

    Raises:
        requests.RequestException: If a download or the conversion request
            fails.
        RuntimeError: If the MP4 conversion response holds no GIF link.
        PIL.UnidentifiedImageError: If the downloaded data is not an image
            that Pillow can open.
    """
    from urllib.parse import urlparse

    # Pasted URLs frequently carry stray whitespace or a trailing newline.
    url = url.strip()
    # The UI labels the prompt as optional, so treat a blank one as "use the
    # default" rather than sending empty text to the processor.
    if not prompt or not prompt.strip():
        prompt = "Describe the image."

    # 1️⃣ Download raw bytes
    raw = download_bytes(url)

    # 2️⃣ Convert MP4 → GIF when needed.  Only the URL *path* is inspected so
    # links with a query string or fragment ("clip.mp4?token=abc") are still
    # recognised as videos.
    if urlparse(url).path.lower().endswith(".mp4"):
        raw = mp4_to_gif(raw)

    # 3️⃣ Load the image (first frame only for animated files)
    img = Image.open(BytesIO(raw))
    if getattr(img, "is_animated", False):
        img = next(ImageSequence.Iterator(img))
    # Hand the processor a plain 3-channel image.
    if img.mode != "RGB":
        img = img.convert("RGB")

    # 4️⃣ Run the model
    inputs = processor(images=img, text=prompt, return_tensors="pt")
    inputs = {k: v.to(llava_model.device) for k, v in inputs.items()}
    with torch.no_grad():
        out_ids = llava_model.generate(**inputs, max_new_tokens=64)
    # NOTE(review): the whole output sequence is decoded; depending on the
    # model this may include the prompt text as well — confirm if only the
    # newly generated caption is wanted.
    return processor.decode(out_ids[0], skip_special_tokens=True)
# -------------------------------------------------
# Gradio UI
# -------------------------------------------------
# Two text inputs (resource URL and prompt) mapped onto the positional
# parameters of ``generate_caption_from_url``; one text output for the caption.
iface = gr.Interface(
    title="JoyCaption – URL input (supports GIF & MP4)",
    description=(
        "Enter a direct URL to an image, an animated GIF, or an MP4 video. MP4 files "
        "are automatically converted to GIF via ezgif.com, and the first frame of "
        "the GIF is captioned."
    ),
    fn=generate_caption_from_url,
    inputs=[
        gr.Textbox(
            placeholder="https://example.com/photo.jpg or https://example.com/clip.mp4",
            label="Image / GIF / MP4 URL",
        ),
        gr.Textbox(value="Describe the image.", label="Prompt (optional)"),
    ],
    outputs=gr.Textbox(label="Generated caption"),
    allow_flagging="never",
)

# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    iface.launch()