File size: 12,351 Bytes
cd5ca02
 
 
7766a5c
028a367
cd5ca02
d125cdc
cd5ca02
 
 
e76c937
cd5ca02
 
e4bf697
7766a5c
 
 
cd5ca02
 
 
 
e4bf697
cd5ca02
 
 
 
 
e4bf697
cd5ca02
 
 
 
 
 
 
e4bf697
cd5ca02
 
 
71b45b9
cd5ca02
 
 
 
 
 
 
 
 
 
 
 
 
7766a5c
cd5ca02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
851e8b5
cd5ca02
 
028a367
851e8b5
 
 
 
7766a5c
851e8b5
49d3ba7
cd5ca02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7766a5c
 
 
 
 
 
cd5ca02
 
7766a5c
cd5ca02
 
 
 
 
7766a5c
 
 
 
cd5ca02
7766a5c
cd5ca02
 
 
028a367
 
cd5ca02
 
 
 
 
 
 
 
 
 
028a367
cd5ca02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7766a5c
 
71b45b9
cd5ca02
 
 
 
 
851e8b5
71b45b9
cd5ca02
851e8b5
71b45b9
 
cd5ca02
 
71b45b9
 
cd5ca02
 
 
 
 
71b45b9
cd5ca02
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import io
import re
import sys
import time
import hashlib
import pathlib
import subprocess
from typing import Optional

import requests
from PIL import Image, ImageSequence
import gradio as gr

# If you still want to use HF AutoProcessor / LlavaForConditionalGeneration for decoding,
# keep transformers installed and uncomment the imports below. This file instead uses
# llama-cpp-python for model inference (GGUF).
from transformers import AutoProcessor

# ----------------------------------------------------------------------
# Config: set model URLs and optional checksums
# ----------------------------------------------------------------------
# Local cache directory for downloaded GGUF model files (created at import time).
MODEL_DIR = pathlib.Path("model")
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Replace these with your preferred GGUF files (mradermacher or TheBloke variants)
Q4_K_M_URL = (
    "https://huggingface.co/mradermacher/joycaption-llama/resolve/main/llama-joycaption-q4_k_m.gguf"
)
Q4_K_S_URL = (
    "https://huggingface.co/mradermacher/joycaption-llama/resolve/main/llama-joycaption-q4_k_s.gguf"
)

# Optional: set SHA256 checksums to validate downloads (replace with real values)
Q4_K_M_SHA256: Optional[str] = None
Q4_K_S_SHA256: Optional[str] = None

# Generation params
MAX_NEW_TOKENS = 128  # upper bound on tokens generated per caption
TEMPERATURE = 0.2  # low temperature -> near-deterministic captions
TOP_P = 0.95  # nucleus-sampling cutoff
STOP_STRS = ["\n"]  # stop at the first newline so the caption stays one line

# HF processor/model name used previously for tokenization/chat template
HF_PROCESSOR_NAME = "fancyfeast/llama-joycaption-beta-one-hf-llava"
HF_TOKEN = os.getenv("HF_TOKEN")  # optional

# ----------------------------------------------------------------------
# Utilities: downloads, checksum, mp4->gif, image load
# ----------------------------------------------------------------------
def download_bytes(url: str, timeout: int = 30) -> bytes:
    """Fetch *url* and return the raw response body.

    Raises requests.HTTPError for non-2xx responses and the usual
    requests exceptions for connection/timeout failures.
    """
    # The original passed stream=True and then read .content, which buffers the
    # whole body anyway — streaming was never used, so a plain GET is equivalent.
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    return resp.content


def mp4_to_gif(mp4_bytes: bytes) -> bytes:
    """Convert MP4 bytes to GIF bytes via the public ezgif.com service.

    Uploads the video, scrapes the resulting GIF URL out of the HTML
    response, then downloads and returns the GIF payload.
    """
    upload = requests.post(
        "https://s.ezgif.com/video-to-gif",
        files={"new-file": ("video.mp4", mp4_bytes, "video/mp4")},
        data={"file": "video.mp4"},
        timeout=120,
    )
    upload.raise_for_status()

    # Try the generic <img> pattern first, then the /tmp/ fallback.
    gif_url = None
    for pattern in (r'<img[^>]+src="([^"]+\.gif)"', r'src="([^"]+?/tmp/[^"]+\.gif)"'):
        found = re.search(pattern, upload.text)
        if found:
            gif_url = found.group(1)
            break
    if gif_url is None:
        raise RuntimeError("Failed to extract GIF URL from ezgif response")

    # Normalise protocol-relative / site-relative URLs to absolute ones.
    if gif_url.startswith("//"):
        gif_url = "https:" + gif_url
    elif gif_url.startswith("/"):
        gif_url = "https://s.ezgif.com" + gif_url

    with requests.get(gif_url, timeout=60) as download:
        download.raise_for_status()
        return download.content


def load_first_frame_from_bytes(raw: bytes) -> Image.Image:
    """Decode *raw* image bytes and return the first frame as an RGB image."""
    decoded = Image.open(io.BytesIO(raw))
    # Animated formats (e.g. GIF): keep only the first frame.
    if getattr(decoded, "is_animated", False):
        decoded = next(ImageSequence.Iterator(decoded))
    # Normalise palette / greyscale / alpha images to plain RGB.
    return decoded if decoded.mode == "RGB" else decoded.convert("RGB")


def sha256_of_file(path: pathlib.Path) -> str:
    """Return the hex SHA-256 digest of *path*, read in 64 KiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        while True:
            chunk = fh.read(65536)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()


def download_file(url: str, dest: pathlib.Path, expected_sha256: Optional[str] = None) -> None:
    """Download *url* to *dest*, streaming in 8 KiB chunks with a progress line.

    If *dest* already exists it is reused: unconditionally when no checksum is
    supplied, or after its SHA-256 matches *expected_sha256*. Raises ValueError
    when the freshly downloaded file fails checksum validation.
    """
    if dest.is_file():
        if expected_sha256 is None:
            # No checksum to validate against: trust the cached file instead of
            # deleting and re-downloading a multi-GB model on every call (the
            # original unconditionally unlinked it here).
            return
        try:
            if sha256_of_file(dest) == expected_sha256:
                return
        except Exception:
            pass
        # remove possibly corrupted/old file
        dest.unlink()
    print(f"Downloading model from {url} -> {dest}")
    with requests.get(url, stream=True, timeout=120) as r:
        r.raise_for_status()
        total = int(r.headers.get("content-length", 0) or 0)
        downloaded = 0
        with open(dest, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                if not chunk:
                    continue
                f.write(chunk)
                downloaded += len(chunk)
                if total:
                    pct = downloaded * 100 // total
                    print(f"\r{dest.name}: {pct}% ", end="", flush=True)
    print()
    if expected_sha256:
        got = sha256_of_file(dest)
        if got != expected_sha256:
            # Don't leave the corrupt file behind to be "cached" on the next call.
            dest.unlink(missing_ok=True)
            raise ValueError(f"Checksum mismatch for {dest}: got {got}, expected {expected_sha256}")


# ----------------------------------------------------------------------
# llama-cpp loading + automated rebuild
# ----------------------------------------------------------------------
def rebuild_llama_cpp() -> None:
    """Force a from-source reinstall of llama-cpp-python.

    Sets PIP_NO_BINARY so pip compiles the wheel locally, upgrading pip and
    the build toolchain first.
    """
    env = dict(os.environ, PIP_NO_BINARY="llama-cpp-python")
    pip_upgrade = [sys.executable, "-m", "pip", "install", "--upgrade"]
    # Same three installs, in the same order, as separate pip invocations.
    for packages in (["pip"], ["cmake", "wheel", "setuptools"], ["llama-cpp-python"]):
        subprocess.check_call(pip_upgrade + packages, env=env)


def try_load_gguf() -> "llama_cpp.Llama":
    """
    Download Q4_K_M then Q4_K_S and attempt to load with llama_cpp.Llama.
    If both fail, rebuild llama-cpp-python from source and retry once.
    """
    import importlib
    from pathlib import Path

    # (url, local destination, optional sha256) in preference order:
    # the q4_k_m quant first, q4_k_s as the smaller fallback.
    candidates = [
        (Q4_K_M_URL, MODEL_DIR / "llama-joycaption-q4_k_m.gguf", Q4_K_M_SHA256),
        (Q4_K_S_URL, MODEL_DIR / "llama-joycaption-q4_k_s.gguf", Q4_K_S_SHA256),
    ]

    # Remembered so the rebuild-failure path below can re-raise something useful.
    last_exc = None

    for url, path, sha in candidates:
        try:
            download_file(url, path, expected_sha256=sha)
            print(f"Attempting to load GGUF: {path}")
            # lazy import so we catch import-time errors before rebuild attempt
            llama_cpp = importlib.import_module("llama_cpp")
            Llama = getattr(llama_cpp, "Llama")
            # minimal params; adjust n_ctx or gpu settings if available
            lm = Llama(model_path=str(path), n_ctx=2048, n_gpu_layers=0, verbose=False)
            print("Model loaded successfully.")
            return lm
        except Exception as e:
            # Broad catch is deliberate: download, import, and load failures all
            # funnel into "try the next candidate".
            print(f"Loading {path.name} failed: {e}")
            last_exc = e

    # If both failed, attempt a rebuild then retry first candidate once
    try:
        print("Both GGUF variants failed to load. Rebuilding llama-cpp-python from source...")
        rebuild_llama_cpp()
    except Exception as e:
        print(f"Rebuild failed: {e}")
        raise last_exc or e

    # After rebuild, import & load primary model
    try:
        import importlib

        # reload() so Python picks up the freshly rebuilt native extension
        # instead of the previously imported (broken) module object.
        llama_cpp = importlib.reload(importlib.import_module("llama_cpp"))
        Llama = getattr(llama_cpp, "Llama")
        path = candidates[0][1]
        if not path.is_file():
            download_file(candidates[0][0], path, expected_sha256=candidates[0][2])
        lm = Llama(model_path=str(path), n_ctx=2048, n_gpu_layers=0, verbose=False)
        print("Model loaded successfully after rebuild.")
        return lm
    except Exception as e:
        print(f"Load after rebuild failed: {e}")
        raise e


# ----------------------------------------------------------------------
# Processor and model wrapper
# ----------------------------------------------------------------------
# We keep AutoProcessor to reuse the chat template behaviour you used previously.
# Loaded eagerly at import time; first run needs network access to the HF hub.
processor = AutoProcessor.from_pretrained(
    HF_PROCESSOR_NAME,
    trust_remote_code=True,
    # NOTE(review): this kwarg is forwarded to the (remote) processor code —
    # confirm the checkpoint's processor actually accepts it.
    num_additional_image_tokens=1,
    **({} if not HF_TOKEN else {"token": HF_TOKEN}),
)

# Lazily-initialised holder for the llama-cpp model.
class ModelWrapper:
    """Wraps a llama_cpp.Llama instance that is loaded on first use."""

    def __init__(self):
        # Populated by ensure_model() on the first generate() call.
        self.llm = None

    def ensure_model(self):
        """Load the GGUF model once; later calls are no-ops."""
        if self.llm is None:
            self.llm = try_load_gguf()

    def generate(self, prompt: str, max_new_tokens: int = MAX_NEW_TOKENS):
        """Run a completion for *prompt* and return the generated text."""
        self.ensure_model()
        # llama-cpp-python call style: model(prompt, max_tokens=..., temperature=...,
        # top_p=..., stop=...); the text lives in response["choices"][0]["text"].
        response = self.llm(
            prompt,
            max_tokens=max_new_tokens,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            stop=STOP_STRS,
        )
        choices = response.get("choices", [{}])
        return choices[0].get("text", "")

MODEL = ModelWrapper()

# ----------------------------------------------------------------------
# Inference: convert URL->image, build prompt via processor chat template, run llama-cpp
# ----------------------------------------------------------------------
def generate_caption_from_url(url: str, prompt: str = "Describe the image.") -> str:
    """Download the media at *url*, build a prompt, and return a caption.

    Never raises: every failure is returned as a human-readable error string
    so the Gradio UI can display it directly.
    """
    if not url:
        return "No URL provided."
    try:
        raw = download_bytes(url)
    except Exception as e:
        return f"Download error: {e}"

    # Strip any query string before checking the file extension.
    lower = url.lower().split("?")[0]
    try:
        # MP4 detection: extension match, or an 'ftyp' box within the first
        # 16 bytes.  NOTE(review): 'ftyp' normally sits at offset 4 so the
        # window covers it, but other ISO-BMFF containers (.mov, .m4v) match
        # too — confirm that is acceptable.
        if lower.endswith(".mp4") or raw[:16].lower().find(b"ftyp") != -1:
            try:
                raw = mp4_to_gif(raw)
            except Exception as e:
                return f"MP4→GIF conversion failed: {e}"
        img = load_first_frame_from_bytes(raw)
    except Exception as e:
        return f"Image processing error: {e}"

    # Resize to a conservative size (512) expected by many VLMs
    # (squashes aspect ratio; resize failures are deliberately ignored).
    try:
        img = img.resize((512, 512), resample=Image.BICUBIC)
    except Exception:
        pass

    try:
        # Produce conversation so the processor inserts image token correctly
        conversation = [
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}
        ]
        inputs = processor.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
            images=img,
        )

        # The processor provides a textual input (input_ids). We'll decode it to a plain prompt
        # string to feed llama-cpp. The processor has a `decode` helper; else we build a simple prompt.
        # Use processor.tokenizer if available to decode input_ids -> text.
        text_prompt = None
        if hasattr(processor, "tokenizer") and getattr(inputs, "input_ids", None) is not None:
            try:
                # inputs may be dict tensors; extract CPU numpy/torch then decode
                input_ids = inputs["input_ids"][0]
                # convert to list of ints if tensor
                import torch
                if hasattr(input_ids, "cpu"):
                    ids = input_ids.cpu().numpy().tolist()
                else:
                    ids = list(input_ids)
                text_prompt = processor.tokenizer.decode(ids, skip_special_tokens=True)
            except Exception:
                text_prompt = None

        if not text_prompt:
            # Fallback: simple textual template with a tag where the image is referenced.
            # NOTE(review): this fallback carries no image data at all, so a
            # text-only GGUF model can only hallucinate a caption — confirm the
            # decode path above actually succeeds in production.
            text_prompt = f"<img> [image here] </img>\n{prompt}\nAnswer:"

        # Debug prints (Space logs)
        print("Prompt to model (truncated):", text_prompt[:512].replace("\n", "\\n"))

        out_text = MODEL.generate(text_prompt, max_new_tokens=MAX_NEW_TOKENS)
        # Postprocess: strip, remove accidental stop tokens, etc.
        return out_text.strip()
    except Exception as e:
        return f"Inference error: {e}"


# ----------------------------------------------------------------------
# Gradio UI (URL + prompt -> text)
# ----------------------------------------------------------------------
# Keyword arguments shared by both Interface construction attempts below.
gradio_kwargs = dict(
    fn=generate_caption_from_url,
    inputs=[
        gr.Textbox(label="Image / GIF / MP4 URL", placeholder="https://example.com/photo.jpg"),
        gr.Textbox(label="Prompt (optional)", value="Describe the image."),
    ],
    outputs=gr.Textbox(label="Generated caption"),
    title="JoyCaption - URL input (GGUF + auto-rebuild)",
    description="Paste a direct link to an image/GIF/MP4 (MP4 will be converted).",
)

# `allow_flagging` raises TypeError on Gradio versions that removed the kwarg;
# fall back to a plain Interface in that case.  NOTE(review): newer Gradio uses
# `flagging_mode` instead — confirm against the pinned version.
try:
    iface = gr.Interface(**gradio_kwargs, allow_flagging="never")
except TypeError:
    iface = gr.Interface(**gradio_kwargs)

if __name__ == "__main__":
    try:
        # Bind on all interfaces at port 7860 (the HF Spaces convention).
        iface.launch(server_name="0.0.0.0", server_port=7860)
    finally:
        # Best-effort cleanup of the default asyncio loop once the server
        # stops; all failures are swallowed.  NOTE(review):
        # asyncio.get_event_loop() is deprecated for this use since Python
        # 3.10 and may create a brand-new loop here — confirm this cleanup
        # is still needed on the target runtime.
        try:
            import asyncio
            loop = asyncio.get_event_loop()
            if not loop.is_closed():
                loop.close()
        except Exception:
            pass