#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import io
import re
import sys
import time
import hashlib
import pathlib
import subprocess
from typing import Optional
from urllib.parse import urljoin

import requests
from PIL import Image, ImageSequence
import gradio as gr

# The HF AutoProcessor is used only to render the JoyCaption chat template into a
# text prompt; model inference itself runs through llama-cpp-python (GGUF).
from transformers import AutoProcessor

# ----------------------------------------------------------------------
# Config: set model URLs and optional checksums
# ----------------------------------------------------------------------
MODEL_DIR = pathlib.Path("model")
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Replace these with your preferred GGUF files (mradermacher or TheBloke variants)
Q4_K_M_URL = (
    "https://huggingface.co/mradermacher/joycaption-llama/resolve/main/llama-joycaption-q4_k_m.gguf"
)
Q4_K_S_URL = (
    "https://huggingface.co/mradermacher/joycaption-llama/resolve/main/llama-joycaption-q4_k_s.gguf"
)

# Optional: set SHA256 checksums to validate downloads (replace with real values)
Q4_K_M_SHA256: Optional[str] = None
Q4_K_S_SHA256: Optional[str] = None

# Generation params
MAX_NEW_TOKENS = 128
TEMPERATURE = 0.2
TOP_P = 0.95
STOP_STRS = ["\n"]

# HF processor/model name used for tokenization/chat template
HF_PROCESSOR_NAME = "fancyfeast/llama-joycaption-beta-one-hf-llava"
HF_TOKEN = os.getenv("HF_TOKEN")  # optional


# ----------------------------------------------------------------------
# Utilities: downloads, checksum, mp4->gif, image load
# ----------------------------------------------------------------------
def download_bytes(url: str, timeout: int = 30, max_bytes: Optional[int] = None) -> bytes:
    """Fetch *url* and return the full response body.

    Args:
        url: HTTP(S) URL to fetch.
        timeout: Timeout in seconds passed to ``requests.get``.
        max_bytes: Optional upper bound on the body size. ``None`` (the
            default) means unlimited. Useful because the URL may be
            user-supplied.

    Raises:
        requests.HTTPError: on a non-2xx status.
        ValueError: if the body grows beyond *max_bytes*.
    """
    with requests.get(url, stream=True, timeout=timeout) as resp:
        resp.raise_for_status()
        if max_bytes is None:
            return resp.content
        # Stream so an oversized body is rejected before it is fully buffered.
        buf = bytearray()
        for chunk in resp.iter_content(chunk_size=65536):
            buf += chunk
            if len(buf) > max_bytes:
                raise ValueError(f"Response from {url} exceeds {max_bytes} bytes")
        return bytes(buf)


# Patterns for locating the converted GIF in ezgif's HTML result page:
# prefer an <img ... src="...gif"> tag, else any src pointing into ezgif's /tmp/ area.
_EZGIF_IMG_RE = re.compile(r'<img[^>]+src="([^"]+\.gif)"')
_EZGIF_TMP_RE = re.compile(r'src="([^"]+?/tmp/[^"]+\.gif)"')


def mp4_to_gif(mp4_bytes: bytes) -> bytes:
    """Convert an MP4 clip to GIF bytes via the ezgif.com web converter.

    Uploads *mp4_bytes*, scrapes the resulting GIF URL out of the returned
    HTML page and downloads it. This depends on ezgif's page layout, so it is
    best-effort.

    Raises:
        requests.HTTPError: if either HTTP request fails.
        RuntimeError: if no GIF URL can be found in the response.
    """
    files = {"new-file": ("video.mp4", mp4_bytes, "video/mp4")}
    resp = requests.post(
        "https://s.ezgif.com/video-to-gif",
        files=files,
        data={"file": "video.mp4"},
        timeout=120,
    )
    resp.raise_for_status()

    match = _EZGIF_IMG_RE.search(resp.text) or _EZGIF_TMP_RE.search(resp.text)
    if not match:
        raise RuntimeError("Failed to extract GIF URL from ezgif response")

    # Resolves protocol-relative ("//host/x.gif"), root-relative ("/tmp/x.gif")
    # and plain relative src values against the ezgif host; absolute URLs pass through.
    gif_url = urljoin("https://s.ezgif.com/", match.group(1))

    with requests.get(gif_url, timeout=60) as gif_resp:
        gif_resp.raise_for_status()
        return gif_resp.content


def load_first_frame_from_bytes(raw: bytes) -> Image.Image:
    """Decode *raw* with Pillow and return its first frame as an RGB image.

    Animated images (``is_animated`` true, e.g. GIF) are reduced to the first
    frame; any non-RGB mode is converted to RGB.

    Raises:
        PIL.UnidentifiedImageError: if Pillow cannot identify the data.
    """
    img = Image.open(io.BytesIO(raw))
    if getattr(img, "is_animated", False):
        img = next(ImageSequence.Iterator(img))
    if img.mode != "RGB":
        img = img.convert("RGB")
    return img


def sha256_of_file(path: pathlib.Path, chunk_size: int = 65536) -> str:
    """Return the hex SHA-256 digest of the file at *path*, read in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(chunk_size), b""):
            h.update(block)
    return h.hexdigest()


def download_file(url: str, dest: pathlib.Path, expected_sha256: Optional[str] = None) -> None:
    """Download *url* to *dest*, reusing an existing valid file.

    An existing, non-empty *dest* is kept when no checksum is configured, or
    when its SHA-256 equals *expected_sha256*. Otherwise it is deleted and
    fetched again. The body is streamed into ``<dest>.part`` and renamed onto
    *dest* only after the transfer completes (and the checksum, if any,
    matches). An interrupted download therefore never leaves a truncated file
    at *dest*.

    Raises:
        requests.HTTPError: on a non-2xx response.
        OSError: if fewer bytes arrived than Content-Length announced.
        ValueError: on checksum mismatch (the partial file is removed).
    """
    if dest.is_file():
        if not expected_sha256:
            # Nothing to validate against: trust a completed, non-empty file
            # instead of re-downloading a multi-GB model on every call.
            if dest.stat().st_size > 0:
                return
        else:
            try:
                if sha256_of_file(dest) == expected_sha256:
                    return
            except OSError:
                pass  # unreadable -> treat as corrupt and re-download
        # remove possibly corrupted/old file
        dest.unlink()

    dest.parent.mkdir(parents=True, exist_ok=True)
    tmp = dest.with_name(dest.name + ".part")
    print(f"Downloading model from {url} -> {dest}")
    try:
        with requests.get(url, stream=True, timeout=120) as r:
            r.raise_for_status()
            total = int(r.headers.get("content-length", 0) or 0)
            downloaded = 0
            last_pct = -1
            with open(tmp, "wb") as f:
                for chunk in r.iter_content(chunk_size=1 << 20):
                    if not chunk:
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total:
                        pct = downloaded * 100 // total
                        # Only repaint when the percentage actually changes.
                        if pct != last_pct:
                            print(f"\r{dest.name}: {pct}% ", end="", flush=True)
                            last_pct = pct
            print()
        if total and downloaded < total:
            raise OSError(f"Incomplete download of {url}: got {downloaded} of {total} bytes")
        if expected_sha256:
            got = sha256_of_file(tmp)
            if got != expected_sha256:
                raise ValueError(f"Checksum mismatch for {dest}: got {got}, expected {expected_sha256}")
        os.replace(tmp, dest)  # atomic on the same filesystem
    except BaseException:
        # Also covers KeyboardInterrupt: never leave a stale partial file behind.
        try:
            tmp.unlink()
        except OSError:
            pass
        raise
# ----------------------------------------------------------------------
# llama-cpp loading + automated rebuild
# ----------------------------------------------------------------------
def rebuild_llama_cpp() -> None:
    """Reinstall llama-cpp-python, forcing a fresh build from its source dist.

    A plain ``--upgrade`` is a no-op when the newest version is already
    installed, and pip would otherwise reuse a previously built wheel from its
    cache. ``--force-reinstall --no-cache-dir`` guarantees a real rebuild.
    ``--no-deps`` keeps that step from force-reinstalling shared dependencies.
    The final plain install then only adds any missing dependencies.

    Raises:
        subprocess.CalledProcessError: if any pip invocation fails.
    """
    env = os.environ.copy()
    env["PIP_NO_BINARY"] = "llama-cpp-python"  # never take a prebuilt wheel
    pip_install = [sys.executable, "-m", "pip", "install"]
    # upgrade pip and the build toolchain first
    subprocess.check_call(pip_install + ["--upgrade", "pip"], env=env)
    subprocess.check_call(pip_install + ["--upgrade", "cmake", "wheel", "setuptools"], env=env)
    subprocess.check_call(
        pip_install + ["--upgrade", "--force-reinstall", "--no-deps", "--no-cache-dir", "llama-cpp-python"],
        env=env,
    )
    subprocess.check_call(pip_install + ["llama-cpp-python"], env=env)


def _import_llama_class(reload: bool = False):
    """Import ``llama_cpp`` and return its ``Llama`` class.

    With *reload*, import caches are invalidated first so a freshly installed
    package is found, and an already-imported module is reloaded.
    NOTE(review): reloading likely does not replace an already-loaded native
    shared library; a process restart may still be needed after a rebuild.
    """
    import importlib

    if reload:
        importlib.invalidate_caches()
        module = sys.modules.get("llama_cpp")
        if module is not None:
            return getattr(importlib.reload(module), "Llama")
    return getattr(importlib.import_module("llama_cpp"), "Llama")


def _load_first_candidate(Llama, candidates):
    """Download and load each ``(url, path, sha256)`` candidate in order.

    Returns ``(model, None)`` for the first candidate that loads, or
    ``(None, last_exception)`` when every candidate failed.
    """
    last_exc = None
    for url, path, sha in candidates:
        try:
            download_file(url, path, expected_sha256=sha)
            print(f"Attempting to load GGUF: {path}")
            # minimal params; adjust n_ctx or gpu settings if available
            return Llama(model_path=str(path), n_ctx=2048, n_gpu_layers=0, verbose=False), None
        except Exception as e:
            print(f"Loading {path.name} failed: {e}")
            last_exc = e
    return None, last_exc


def try_load_gguf() -> "llama_cpp.Llama":
    """Download a JoyCaption GGUF model and load it with ``llama_cpp.Llama``.

    Tries the Q4_K_M file, then Q4_K_S. If ``llama_cpp`` cannot be imported,
    or neither file loads, llama-cpp-python is rebuilt from source and both
    files are tried once more.

    Raises:
        Exception: the last import/load error if the rebuild fails, or the
            last load error if loading still fails after the rebuild.
    """
    candidates = [
        (Q4_K_M_URL, MODEL_DIR / "llama-joycaption-q4_k_m.gguf", Q4_K_M_SHA256),
        (Q4_K_S_URL, MODEL_DIR / "llama-joycaption-q4_k_s.gguf", Q4_K_S_SHA256),
    ]

    last_exc = None
    try:
        Llama = _import_llama_class()
    except Exception as e:
        # If the binding cannot even be imported, every candidate would fail:
        # go straight to the rebuild instead of downloading multi-GB files first.
        print(f"Importing llama_cpp failed: {e}")
        last_exc = e
    else:
        lm, last_exc = _load_first_candidate(Llama, candidates)
        if lm is not None:
            print("Model loaded successfully.")
            return lm

    print("Both GGUF variants failed to load. Rebuilding llama-cpp-python from source...")
    try:
        rebuild_llama_cpp()
    except Exception as e:
        print(f"Rebuild failed: {e}")
        if last_exc is None:
            raise
        raise last_exc from e

    # After rebuild, re-import and retry every candidate
    try:
        Llama = _import_llama_class(reload=True)
    except Exception as e:
        print(f"Load after rebuild failed: {e}")
        raise
    lm, err = _load_first_candidate(Llama, candidates)
    if lm is None:
        print(f"Load after rebuild failed: {err}")
        raise err
    print("Model loaded successfully after rebuild.")
    return lm


# ----------------------------------------------------------------------
# Processor and model wrapper
# ----------------------------------------------------------------------
# AutoProcessor is kept only to render the model's chat template into text.
processor = AutoProcessor.from_pretrained(
    HF_PROCESSOR_NAME,
    trust_remote_code=True,
    num_additional_image_tokens=1,
    **({} if not HF_TOKEN else {"token": HF_TOKEN}),
)


class ModelWrapper:
    """Lazy holder that loads the GGUF model on first use and runs completions."""

    def __init__(self):
        self.llm = None  # llama-cpp Llama instance, created by ensure_model()

    def ensure_model(self):
        """Load the model via try_load_gguf() if it has not been loaded yet."""
        if self.llm is None:
            self.llm = try_load_gguf()

    def generate(self, prompt: str, max_new_tokens: int = MAX_NEW_TOKENS) -> str:
        """Complete *prompt* and return the generated text ("" if there is none)."""
        self.ensure_model()
        out = self.llm(
            prompt,
            max_tokens=max_new_tokens,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            stop=STOP_STRS,
        )
        # llama-cpp-python completions put the text in out["choices"][0]["text"];
        # guard against a missing or empty "choices" list.
        choices = out.get("choices") or [{}]
        return choices[0].get("text", "") or ""


MODEL = ModelWrapper()


# ----------------------------------------------------------------------
# Inference: convert URL->image, build prompt via processor chat template, run llama-cpp
# ----------------------------------------------------------------------
def _build_text_prompt(prompt: str) -> str:
    """Render a single-image user turn through the processor's chat template.

    ``tokenize=False`` returns the formatted string with the template's
    special tokens (role headers etc.) intact. Decoding token ids with
    ``skip_special_tokens=True`` would strip them. Falls back to a minimal
    hand-written template if rendering fails or yields nothing.
    """
    conversation = [
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}
    ]
    text_prompt = None
    try:
        text_prompt = processor.apply_chat_template(
            conversation, add_generation_prompt=True, tokenize=False
        )
    except Exception as e:
        print(f"Chat template rendering failed, using fallback prompt: {e}")

    if isinstance(text_prompt, list):  # batched rendering returns a list of strings
        text_prompt = text_prompt[0] if text_prompt else None

    if text_prompt:
        # llama-cpp-python normally prepends BOS itself when tokenizing a
        # completion prompt (TODO confirm for the installed version); drop the
        # template's copy so it is not doubled.
        bos = getattr(getattr(processor, "tokenizer", None), "bos_token", None)
        if bos and text_prompt.startswith(bos):
            text_prompt = text_prompt[len(bos):]

    if not text_prompt:
        # Fallback: simple textual template with a tag where the image is referenced.
        text_prompt = f" [image here] \n{prompt}\nAnswer:"
    return text_prompt


def generate_caption_from_url(url: str, prompt: str = "Describe the image.") -> str:
    """Fetch an image/GIF/MP4 from *url* and return the model's caption.

    Failures are returned as a human-readable message (shown in the Gradio
    output box) rather than raised.
    """
    url = (url or "").strip()
    if not url:
        return "No URL provided."
    if not prompt or not prompt.strip():
        # The UI labels the prompt as optional: treat blank input as the default.
        prompt = "Describe the image."

    try:
        raw = download_bytes(url)
    except Exception as e:
        return f"Download error: {e}"

    lower = url.lower().split("?")[0]
    try:
        # MP4 (ISO-BMFF) files carry an "ftyp" box near the start of the data.
        if lower.endswith(".mp4") or b"ftyp" in raw[:16].lower():
            try:
                raw = mp4_to_gif(raw)
            except Exception as e:
                return f"MP4→GIF conversion failed: {e}"
        img = load_first_frame_from_bytes(raw)
    except Exception as e:
        return f"Image processing error: {e}"

    # Resize to a conservative size (512) expected by many VLMs
    try:
        img = img.resize((512, 512), resample=Image.BICUBIC)
    except Exception:
        pass

    # NOTE(review): the decoded image is NOT passed to llama-cpp. The GGUF model
    # is loaded as a text-only Llama (no mmproj / multimodal chat handler), so
    # captions cannot reflect the actual image content until a vision handler
    # is wired in. Loading above still validates that the URL is an image.
    try:
        text_prompt = _build_text_prompt(prompt)

        # Debug prints (Space logs)
        print("Prompt to model (truncated):", text_prompt[:512].replace("\n", "\\n"))

        out_text = MODEL.generate(text_prompt, max_new_tokens=MAX_NEW_TOKENS)
        return out_text.strip()
    except Exception as e:
        return f"Inference error: {e}"


# ----------------------------------------------------------------------
# Gradio UI (URL + prompt -> text)
# ----------------------------------------------------------------------
gradio_kwargs = dict(
    fn=generate_caption_from_url,
    inputs=[
        gr.Textbox(label="Image / GIF / MP4 URL", placeholder="https://example.com/photo.jpg"),
        gr.Textbox(label="Prompt (optional)", value="Describe the image."),
    ],
    outputs=gr.Textbox(label="Generated caption"),
    title="JoyCaption - URL input (GGUF + auto-rebuild)",
    description="Paste a direct link to an image/GIF/MP4 (MP4 will be converted).",
)


def _gradio_major_version() -> int:
    """Return Gradio's major version number, or 0 if it cannot be parsed."""
    try:
        return int(str(getattr(gr, "__version__", "0")).split(".")[0])
    except ValueError:
        return 0


# Disable flagging: Gradio 5 renamed `allow_flagging` to `flagging_mode`.
_flagging_kwargs = (
    {"flagging_mode": "never"} if _gradio_major_version() >= 5 else {"allow_flagging": "never"}
)
try:
    iface = gr.Interface(**gradio_kwargs, **_flagging_kwargs)
except TypeError:
    iface = gr.Interface(**gradio_kwargs)


if __name__ == "__main__":
    try:
        iface.launch(server_name="0.0.0.0", server_port=7860)
    finally:
        # Best-effort: close the default asyncio loop on shutdown.
        try:
            import asyncio

            loop = asyncio.get_event_loop()
            if not loop.is_closed():
                loop.close()
        except Exception:
            pass