File size: 4,143 Bytes
7766a5c
 
 
f275d7c
 
7766a5c
 
 
 
09c7c56
 
 
 
 
71b45b9
09c7c56
 
 
7766a5c
 
71b45b9
 
 
851e8b5
09c7c56
71b45b9
 
 
09c7c56
 
 
7766a5c
 
851e8b5
 
 
 
 
 
 
 
 
7766a5c
851e8b5
 
 
7766a5c
 
851e8b5
 
 
 
 
 
 
7766a5c
851e8b5
 
 
7766a5c
851e8b5
 
 
 
 
7766a5c
851e8b5
09c7c56
 
 
7766a5c
 
 
 
 
 
 
 
 
 
09c7c56
 
7766a5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71b45b9
09c7c56
 
 
 
 
851e8b5
71b45b9
7766a5c
851e8b5
71b45b9
 
09c7c56
7766a5c
71b45b9
 
09c7c56
 
 
 
 
71b45b9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import html
import inspect
import os
import re
from io import BytesIO
from urllib.parse import urljoin, urlsplit

import gradio as gr
import requests
import torch
from PIL import Image, ImageOps, ImageSequence
from transformers import AutoProcessor, LlavaForConditionalGeneration

# ---------------------------
# Config
# ---------------------------
# Hugging Face Hub repo id of the JoyCaption LLaVA checkpoint.
MODEL_NAME = "fancyfeast/llama-joycaption-beta-one-hf-llava"
# Optional access token read from the environment (e.g. a Space secret);
# None when HF_TOKEN is not set.
HF_TOKEN = os.getenv("HF_TOKEN")

# ---------------------------
# Load model & processor
# ---------------------------
# Forward `token=` to from_pretrained only when a token is configured.
token_arg = {}
if HF_TOKEN:
    token_arg["token"] = HF_TOKEN

processor = AutoProcessor.from_pretrained(MODEL_NAME, **token_arg)
llava_model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
    **token_arg,
)
# Switch the module to evaluation mode; this app only runs inference.
llava_model.eval()

# ---------------------------
# Helpers
# ---------------------------
def download_bytes(url: str, timeout: int = 30, max_bytes: "int | None" = None) -> bytes:
    """Fetch *url* over HTTP(S) and return the full response body.

    Args:
        url: Address to download.
        timeout: Timeout in seconds passed straight to ``requests.get``.
        max_bytes: Optional upper bound on the body size. ``None`` (the
            default) means no limit.

    Returns:
        The raw response body.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        ValueError: if ``max_bytes`` is set and the body is larger.
    """
    # Using the response as a context manager guarantees the streamed
    # connection is released even when raise_for_status() raises.
    with requests.get(url, stream=True, timeout=timeout) as resp:
        resp.raise_for_status()
        if max_bytes is None:
            return resp.content
        # Read incrementally so an oversized body is rejected without first
        # buffering all of it in memory.
        chunks = []
        received = 0
        for chunk in resp.iter_content(chunk_size=64 * 1024):
            received += len(chunk)
            if received > max_bytes:
                raise ValueError(f"Response body exceeds {max_bytes} bytes")
            chunks.append(chunk)
        return b"".join(chunks)

def _extract_gif_url(page_html: str, base_url: str) -> str:
    """Return the absolute URL of the output GIF referenced by an ezgif page.

    Raises:
        RuntimeError: if no candidate ``.gif`` URL is present in *page_html*.
    """
    # Try the converter's temporary-output path first; only fall back to "any
    # <img> pointing at a .gif", which could also match unrelated page images.
    patterns = (
        r'src="([^"]+?/tmp/[^"]+\.gif)"',
        r'<img[^>]+src="([^"]+\.gif)"',
    )
    for pattern in patterns:
        match = re.search(pattern, page_html)
        if match:
            # Attribute values may contain HTML entities (e.g. "&amp;"), and the
            # link may be protocol-relative, root-relative or path-relative;
            # urljoin resolves all of these against the page URL like a browser.
            return urljoin(base_url, html.unescape(match.group(1)))
    raise RuntimeError("Failed to extract GIF URL from ezgif response")


def mp4_to_gif(mp4_bytes: bytes) -> bytes:
    """Convert MP4 bytes to GIF bytes via the ezgif.com web converter.

    Uploads the video, scrapes the returned HTML page for the GIF's URL and
    downloads that GIF.

    NOTE(review): this depends on scraping ezgif's HTML, which is not a stable
    API. Confirm that a single upload POST yields a finished GIF rather than an
    intermediate page that needs a further conversion request.

    Raises:
        requests.HTTPError: if the upload or the GIF download fails.
        RuntimeError: if no GIF URL can be found in the response page.
    """
    files = {"new-file": ("video.mp4", mp4_bytes, "video/mp4")}
    resp = requests.post(
        "https://s.ezgif.com/video-to-gif",
        files=files,
        data={"file": "video.mp4"},
        timeout=120,
    )
    resp.raise_for_status()
    # resp.url is the page actually served (after any redirect), which is the
    # correct base for resolving relative links found in it.
    gif_url = _extract_gif_url(resp.text, resp.url)
    gif_resp = requests.get(gif_url, timeout=60)
    gif_resp.raise_for_status()
    return gif_resp.content

def load_first_frame_from_bytes(raw: bytes) -> Image.Image:
    """Decode image bytes and return the first frame as an RGB image.

    Animated formats (e.g. GIF) are reduced to their first frame. EXIF
    orientation is applied so that rotated photos are captioned upright.
    Transparency is flattened onto a white background rather than discarded:
    a bare ``convert("RGB")`` keeps the hidden colour stored under transparent
    pixels, which is often black.

    Raises:
        PIL.UnidentifiedImageError: if Pillow cannot identify *raw* as an image.
    """
    img = Image.open(BytesIO(raw))
    if getattr(img, "is_animated", False):
        img.seek(0)  # caption only the first frame
    # Rotate/flip according to the EXIF Orientation tag (no-op if absent).
    img = ImageOps.exif_transpose(img)

    has_alpha = img.mode in ("RGBA", "LA", "PA") or (
        img.mode == "P" and "transparency" in img.info
    )
    if has_alpha:
        rgba = img.convert("RGBA")
        flattened = Image.new("RGB", rgba.size, (255, 255, 255))
        flattened.paste(rgba, mask=rgba.getchannel("A"))
        return flattened

    if img.mode != "RGB":
        img = img.convert("RGB")
    return img

# ---------------------------
# Main inference
# ---------------------------
# Major brands of ISO-BMFF image formats (AVIF/HEIF). These files begin with
# the same "ftyp" box as MP4 video but should be decoded as images, not sent
# to the video converter.
_ISOBMFF_IMAGE_BRANDS = frozenset(
    {b"avif", b"avis", b"heic", b"heix", b"heim", b"heis", b"hevc", b"hevx", b"mif1", b"msf1"}
)


def _looks_like_mp4(url: str, raw: bytes) -> bool:
    """Heuristically decide whether the payload fetched from *url* is MP4-style video."""
    if urlsplit(url).path.lower().endswith(".mp4"):
        return True
    # An ISO-BMFF file starts with a box: a 4-byte size, the 4-byte box type
    # ("ftyp" for the first box), then the 4-byte major brand.
    return raw[4:8] == b"ftyp" and raw[8:12] not in _ISOBMFF_IMAGE_BRANDS


def generate_caption_from_url(
    url: str, prompt: str = "Describe the image.", max_new_tokens: int = 128
) -> str:
    """Download media from *url* and return a caption generated by the model.

    Images are captioned directly. MP4 videos are first converted to GIF and
    their first frame is captioned.

    Args:
        url: Link to an image, GIF or MP4. Surrounding whitespace is ignored.
        prompt: Instruction for the captioner. Empty/blank falls back to the
            default prompt.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        The caption, or a human-readable error message. This is a UI handler,
        so failures are reported as text instead of being raised.
    """
    url = (url or "").strip()
    if not url:
        return "No URL provided."
    # Gradio passes "" when the prompt box is cleared.
    prompt = (prompt or "").strip() or "Describe the image."

    try:
        raw = download_bytes(url)
    except Exception as e:
        return f"Download error: {e}"

    try:
        if _looks_like_mp4(url, raw):
            try:
                raw = mp4_to_gif(raw)
            except Exception as e:
                return f"MP4→GIF conversion failed: {e}"
        img = load_first_frame_from_bytes(raw)
    except Exception as e:
        return f"Image processing error: {e}"

    try:
        # Per the JoyCaption model card, the prompt must be formatted with the
        # processor's chat template; a bare prompt string lacks the image
        # placeholder that LLaVA needs to splice in the image features.
        convo = [
            {"role": "system", "content": "You are a helpful image captioner."},
            {"role": "user", "content": prompt},
        ]
        convo_string = processor.apply_chat_template(
            convo, tokenize=False, add_generation_prompt=True
        )
        inputs = processor(text=[convo_string], images=[img], return_tensors="pt")
        inputs = {k: v.to(llava_model.device) for k, v in inputs.items()}
        # The processor produces float32 pixels. Cast them to the model's
        # weight dtype (loaded as bfloat16 in this file) to avoid a dtype
        # mismatch in the vision tower.
        inputs["pixel_values"] = inputs["pixel_values"].to(llava_model.dtype)
        with torch.no_grad():
            out_ids = llava_model.generate(**inputs, max_new_tokens=max_new_tokens)
        # generate() returns prompt + continuation; decode only the new tokens.
        prompt_len = inputs["input_ids"].shape[1]
        caption = processor.tokenizer.decode(
            out_ids[0][prompt_len:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return caption.strip()
    except Exception as e:
        return f"Inference error: {e}"

# ---------------------------
# Gradio UI (compatible init)
# ---------------------------
gradio_kwargs = dict(
    fn=generate_caption_from_url,
    inputs=[
        gr.Textbox(label="Image / GIF / MP4 URL", placeholder="https://example.com/photo.jpg"),
        gr.Textbox(label="Prompt (optional)", value="Describe the image."),
    ],
    outputs=gr.Textbox(label="Generated caption"),
    title="JoyCaption (fancyfeast) - URL input",
    description="Paste a direct link to an image, GIF, or MP4. MP4 files are converted to GIF via ezgif.com; the first frame is captioned.",
)

# Disable the "Flag" button in a version-tolerant way. Newer Gradio releases
# name the option `flagging_mode`, older ones `allow_flagging`. Pick whichever
# this installed version declares instead of relying on a TypeError, which
# would only fire for one of the two names. If neither exists, keep defaults.
_interface_params = inspect.signature(gr.Interface.__init__).parameters
_flag_kwargs = {}
if "flagging_mode" in _interface_params:
    _flag_kwargs["flagging_mode"] = "never"
elif "allow_flagging" in _interface_params:
    _flag_kwargs["allow_flagging"] = "never"

iface = gr.Interface(**gradio_kwargs, **_flag_kwargs)

if __name__ == "__main__":
    iface.launch()