File size: 4,889 Bytes
7766a5c
028a367
7aed240
7766a5c
f275d7c
7766a5c
 
 
 
09c7c56
028a367
71b45b9
7766a5c
028a367
 
 
851e8b5
 
 
 
 
 
 
7766a5c
851e8b5
 
 
7766a5c
 
851e8b5
 
 
 
 
 
 
028a367
 
 
851e8b5
7766a5c
028a367
851e8b5
 
 
 
7766a5c
851e8b5
028a367
 
 
 
 
7aed240
028a367
 
35d219a
028a367
 
 
7aed240
028a367
 
 
 
 
7766a5c
 
 
 
 
 
 
 
 
 
09c7c56
7766a5c
 
 
 
 
 
 
 
35d219a
7766a5c
028a367
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35d219a
028a367
 
 
 
 
7766a5c
 
 
 
 
 
71b45b9
09c7c56
851e8b5
71b45b9
7766a5c
851e8b5
71b45b9
 
028a367
 
71b45b9
 
09c7c56
 
 
 
 
71b45b9
028a367
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os
import io
import re
import torch
import requests
from PIL import Image, ImageSequence
from transformers import AutoProcessor, LlavaForConditionalGeneration
import gradio as gr

# Hugging Face model id for the JoyCaption LLaVA-based captioner.
MODEL_NAME = "fancyfeast/llama-joycaption-beta-one-hf-llava"
# Auth token for gated/private model access; read from the environment.
HF_TOKEN = os.getenv("HF_TOKEN")  # optional

def download_bytes(url: str, timeout: int = 30) -> bytes:
    """Fetch *url* and return the full response body as bytes.

    Raises requests.HTTPError on a non-2xx status and other
    requests exceptions on connection/timeout failures.
    """
    resp = requests.get(url, stream=True, timeout=timeout)
    try:
        resp.raise_for_status()
        return resp.content
    finally:
        resp.close()

def mp4_to_gif(mp4_bytes: bytes) -> bytes:
    """Convert MP4 bytes to GIF bytes via the ezgif.com web service.

    Uploads the video, scrapes the resulting GIF URL out of the HTML
    response, downloads that GIF, and returns its bytes.

    Raises RuntimeError when no GIF URL can be found in the response,
    and requests exceptions on network/HTTP failures.
    """
    upload = requests.post(
        "https://s.ezgif.com/video-to-gif",
        files={"new-file": ("video.mp4", mp4_bytes, "video/mp4")},
        data={"file": "video.mp4"},
        timeout=120,
    )
    upload.raise_for_status()

    # Primary pattern: any <img> tag whose src ends in .gif; fallback:
    # a bare src attribute under ezgif's /tmp/ output directory.
    html = upload.text
    match = (
        re.search(r'<img[^>]+src="([^"]+\.gif)"', html)
        or re.search(r'src="([^"]+?/tmp/[^"]+\.gif)"', html)
    )
    if match is None:
        raise RuntimeError("Failed to extract GIF URL from ezgif response")

    gif_url = match.group(1)
    # Normalize protocol-relative (//host/...) and site-relative (/...)
    # URLs to absolute https URLs.
    if gif_url.startswith("//"):
        gif_url = "https:" + gif_url
    elif gif_url.startswith("/"):
        gif_url = "https://s.ezgif.com" + gif_url

    gif_resp = requests.get(gif_url, timeout=60)
    try:
        gif_resp.raise_for_status()
        return gif_resp.content
    finally:
        gif_resp.close()

def load_first_frame_from_bytes(raw: bytes) -> Image.Image:
    """Decode *raw* image bytes and return the first frame as an RGB image.

    For animated formats (e.g. GIF) only frame 0 is kept; static images
    pass straight through. The result is always RGB mode.
    """
    image = Image.open(io.BytesIO(raw))
    if getattr(image, "is_animated", False):
        # Keep just the first frame of an animation.
        image = next(ImageSequence.Iterator(image))
    return image if image.mode == "RGB" else image.convert("RGB")

# Load processor + model.
# BUG FIX: token_arg was previously built with the deprecated
# `use_auth_token` key and then never used — both from_pretrained calls
# rebuilt their own inline dicts. Build it once with the modern `token`
# key and reuse it.
token_arg = {"token": HF_TOKEN} if HF_TOKEN else {}
processor = AutoProcessor.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    num_additional_image_tokens=1,
    **token_arg,
)
# CPU Space -> use float32 (half precision is slow/poorly supported on CPU)
llava_model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    device_map="cpu",
    torch_dtype=torch.float32,
    trust_remote_code=True,
    **token_arg,
)
llava_model.eval()  # inference only: disables dropout etc.

def generate_caption_from_url(url: str, prompt: str = "Describe the image.") -> str:
    """Download media from *url* and return a model-generated caption.

    Accepts direct links to images, GIFs, or MP4s (MP4s are first
    converted to GIF via ezgif; only the first frame is captioned).
    All failures are reported as human-readable strings rather than
    raised, because the return value is shown directly in the Gradio UI.
    """
    if not url:
        return "No URL provided."
    try:
        raw = download_bytes(url)
    except Exception as e:
        return f"Download error: {e}"

    # Detect MP4 either by extension (query string stripped) or by the
    # ISO-BMFF "ftyp" marker near the start of the file.
    lower = url.lower().split("?")[0]
    try:
        if lower.endswith(".mp4") or b"ftyp" in raw[:16].lower():
            try:
                raw = mp4_to_gif(raw)
            except Exception as e:
                return f"MP4→GIF conversion failed: {e}"
        img = load_first_frame_from_bytes(raw)
    except Exception as e:
        return f"Image processing error: {e}"

    # Resize to conservative default; best-effort only (ignores failures).
    try:
        img = img.resize((512, 512), resample=Image.BICUBIC)
    except Exception:
        pass

    try:
        conversation = [
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}
        ]
        inputs = processor.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
            images=img,
        )
        device = llava_model.device
        inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
        # Pixel values must match the model dtype (float32 on CPU).
        if "pixel_values" in inputs:
            inputs["pixel_values"] = inputs["pixel_values"].to(dtype=llava_model.dtype, device=device)

        # Minimal debug info (appears in Space logs)
        if "pixel_values" in inputs:
            print("pixel_values.shape:", inputs["pixel_values"].shape)
        if "input_ids" in inputs:
            print("input_ids.shape:", inputs["input_ids"].shape)

        with torch.no_grad():
            out_ids = llava_model.generate(**inputs, max_new_tokens=128)
        # BUG FIX: decode only the newly generated tokens. Decoding the
        # full output sequence echoed the chat template and prompt text
        # into the returned caption.
        prompt_len = inputs["input_ids"].shape[1]
        caption = processor.decode(out_ids[0][prompt_len:], skip_special_tokens=True)
        return caption.strip()
    except Exception as e:
        return f"Inference error: {e}"

# Shared kwargs for gr.Interface, kept in a dict so construction can be
# retried below with and without the version-dependent flagging option.
gradio_kwargs = dict(
    fn=generate_caption_from_url,
    inputs=[
        gr.Textbox(label="Image / GIF / MP4 URL", placeholder="https://example.com/photo.jpg"),
        gr.Textbox(label="Prompt (optional)", value="Describe the image."),
    ],
    outputs=gr.Textbox(label="Generated caption"),
    title="JoyCaption - URL input",
    description="Paste a direct link to an image/GIF/MP4 (MP4 will be converted).",
)

# NOTE(review): `allow_flagging` raises TypeError on Gradio versions that
# no longer accept it — fall back to constructing without it.
try:
    iface = gr.Interface(**gradio_kwargs, allow_flagging="never")
except TypeError:
    iface = gr.Interface(**gradio_kwargs)

if __name__ == "__main__":
    try:
        # 0.0.0.0:7860 is the standard bind address/port for HF Spaces.
        iface.launch(server_name="0.0.0.0", server_port=7860)
    finally:
        # Best-effort event-loop cleanup after the server stops; any
        # failure here is deliberately swallowed.
        # NOTE(review): asyncio.get_event_loop() outside a running loop is
        # deprecated since Python 3.10 — presumably this works around a
        # shutdown hang/warning in the hosting environment; confirm it is
        # still needed.
        try:
            import asyncio
            loop = asyncio.get_event_loop()
            if not loop.is_closed():
                loop.close()
        except Exception:
            pass