# Hug0endob's picture
# Update app.py
# cd5ca02 verified
# raw
# history blame
# 12.4 kB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import io
import re
import sys
import time
import hashlib
import pathlib
import subprocess
from typing import Optional
import requests
from PIL import Image, ImageSequence
import gradio as gr
# If you still want to use HF AutoProcessor / LlavaForConditionalGeneration for decoding,
# keep transformers installed and uncomment the imports below. This file instead uses
# llama-cpp-python for model inference (GGUF).
from transformers import AutoProcessor
# ----------------------------------------------------------------------
# Config: set model URLs and optional checksums
# ----------------------------------------------------------------------
# Local cache directory for downloaded GGUF model files.
MODEL_DIR = pathlib.Path("model")
MODEL_DIR.mkdir(parents=True, exist_ok=True)
# Replace these with your preferred GGUF files (mradermacher or TheBloke variants)
# Primary (larger, higher-quality) quantization.
Q4_K_M_URL = (
    "https://huggingface.co/mradermacher/joycaption-llama/resolve/main/llama-joycaption-q4_k_m.gguf"
)
# Fallback (smaller) quantization, tried second by try_load_gguf().
Q4_K_S_URL = (
    "https://huggingface.co/mradermacher/joycaption-llama/resolve/main/llama-joycaption-q4_k_s.gguf"
)
# Optional: set SHA256 checksums to validate downloads (replace with real values)
Q4_K_M_SHA256: Optional[str] = None
Q4_K_S_SHA256: Optional[str] = None
# Generation params
MAX_NEW_TOKENS = 128  # default max tokens generated per caption
TEMPERATURE = 0.2     # low temperature -> mostly deterministic captions
TOP_P = 0.95          # nucleus sampling cutoff
STOP_STRS = ["\n"]    # generation stops at the first newline
# HF processor/model name used previously for tokenization/chat template
HF_PROCESSOR_NAME = "fancyfeast/llama-joycaption-beta-one-hf-llava"
HF_TOKEN = os.getenv("HF_TOKEN")  # optional
# ----------------------------------------------------------------------
# Utilities: downloads, checksum, mp4->gif, image load
# ----------------------------------------------------------------------
def download_bytes(url: str, timeout: int = 30) -> bytes:
    """Fetch *url* and return the full response body as bytes.

    Raises requests.HTTPError on a non-2xx status and
    requests.RequestException on network failure.
    """
    # The old stream=True was pointless: .content buffers the whole body
    # anyway, so request it plainly and let the context manager release
    # the connection.
    with requests.get(url, timeout=timeout) as resp:
        resp.raise_for_status()
        return resp.content
def mp4_to_gif(mp4_bytes: bytes) -> bytes:
    """Convert MP4 bytes to GIF bytes via the ezgif.com web service.

    Raises RuntimeError when no GIF URL can be scraped from the response,
    and requests.HTTPError on any non-2xx HTTP status.
    """
    # NOTE(review): ezgif's real flow is normally upload -> separate convert
    # step; this assumes the first response already embeds the GIF URL —
    # verify against the current site behaviour.
    files = {"new-file": ("video.mp4", mp4_bytes, "video/mp4")}
    resp = requests.post(
        "https://s.ezgif.com/video-to-gif",
        files=files,
        data={"file": "video.mp4"},
        timeout=120,
    )
    resp.raise_for_status()
    # Scrape the first <img src="...gif"> from the returned HTML.
    match = re.search(r'<img[^>]+src="([^"]+\.gif)"', resp.text)
    if not match:
        # Fallback: any src attribute pointing at an ezgif /tmp/ .gif artifact.
        match = re.search(r'src="([^"]+?/tmp/[^"]+\.gif)"', resp.text)
    if not match:
        raise RuntimeError("Failed to extract GIF URL from ezgif response")
    gif_url = match.group(1)
    # Normalize protocol-relative (//...) and site-relative (/...) URLs.
    if gif_url.startswith("//"):
        gif_url = "https:" + gif_url
    elif gif_url.startswith("/"):
        gif_url = "https://s.ezgif.com" + gif_url
    with requests.get(gif_url, timeout=60) as gif_resp:
        gif_resp.raise_for_status()
        return gif_resp.content
def load_first_frame_from_bytes(raw: bytes) -> Image.Image:
    """Decode *raw* image bytes and return the first frame as an RGB PIL image."""
    frame = Image.open(io.BytesIO(raw))
    # Animated formats (e.g. GIF): keep only the first frame.
    if getattr(frame, "is_animated", False):
        frame = next(ImageSequence.Iterator(frame))
    return frame if frame.mode == "RGB" else frame.convert("RGB")
def sha256_of_file(path: pathlib.Path) -> str:
    """Return the hex-encoded SHA-256 digest of the file at *path*.

    Reads in 64 KiB chunks so large model files never need to fit in memory.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        while chunk := fh.read(65536):
            digest.update(chunk)
    return digest.hexdigest()
def download_file(url: str, dest: pathlib.Path, expected_sha256: Optional[str] = None) -> None:
    """Download *url* to *dest*, streaming in chunks with a progress display.

    Cache handling when *dest* already exists:
      * with *expected_sha256* set, the file is kept only if its checksum
        matches; otherwise it is removed and re-downloaded;
      * without a checksum, the cached file is reused as-is (the previous
        code unlinked and re-downloaded the multi-GB model on every call).

    The download is written to a ``.part`` temp file and atomically renamed
    into place, so *dest* never exists in a half-written state.

    Raises ValueError when the downloaded file fails checksum verification,
    and requests.HTTPError / requests.RequestException on HTTP failures.
    """
    if dest.is_file():
        if not expected_sha256:
            # No way to validate -- trust the cached copy instead of
            # deleting it and re-downloading on every startup.
            return
        try:
            if sha256_of_file(dest) == expected_sha256:
                return
        except Exception:
            pass
        # Stale or corrupted cache entry: remove and re-download.
        dest.unlink()
    print(f"Downloading model from {url} -> {dest}")
    tmp = dest.with_suffix(dest.suffix + ".part")
    with requests.get(url, stream=True, timeout=120) as r:
        r.raise_for_status()
        total = int(r.headers.get("content-length", 0) or 0)
        downloaded = 0
        with open(tmp, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                if not chunk:
                    continue
                f.write(chunk)
                downloaded += len(chunk)
                if total:
                    # Integer percentage; \r keeps progress on one log line.
                    pct = downloaded * 100 // total
                    print(f"\r{dest.name}: {pct}% ", end="", flush=True)
    print()
    if expected_sha256:
        got = sha256_of_file(tmp)
        if got != expected_sha256:
            # Don't leave a known-bad file behind.
            tmp.unlink(missing_ok=True)
            raise ValueError(f"Checksum mismatch for {dest}: got {got}, expected {expected_sha256}")
    # Atomic publish of the fully-written file.
    tmp.replace(dest)
# ----------------------------------------------------------------------
# llama-cpp loading + automated rebuild
# ----------------------------------------------------------------------
def rebuild_llama_cpp() -> None:
    """Force a from-source reinstall of llama-cpp-python.

    PIP_NO_BINARY makes pip compile the package locally instead of
    installing a prebuilt wheel that may not match this host.
    """
    env = dict(os.environ, PIP_NO_BINARY="llama-cpp-python")
    pip_upgrade = [sys.executable, "-m", "pip", "install", "--upgrade"]
    # Upgrade pip itself, then the build toolchain, then rebuild the package.
    for packages in (["pip"], ["cmake", "wheel", "setuptools"], ["llama-cpp-python"]):
        subprocess.check_call(pip_upgrade + packages, env=env)
def try_load_gguf() -> "llama_cpp.Llama":
    """
    Download Q4_K_M then Q4_K_S and attempt to load with llama_cpp.Llama.
    If both fail, rebuild llama-cpp-python from source and retry once.
    """
    import importlib
    from pathlib import Path  # NOTE(review): unused import — kept as-is
    # Candidates tried in order: (download URL, local cache path, optional sha256).
    candidates = [
        (Q4_K_M_URL, MODEL_DIR / "llama-joycaption-q4_k_m.gguf", Q4_K_M_SHA256),
        (Q4_K_S_URL, MODEL_DIR / "llama-joycaption-q4_k_s.gguf", Q4_K_S_SHA256),
    ]
    last_exc = None
    for url, path, sha in candidates:
        try:
            download_file(url, path, expected_sha256=sha)
            print(f"Attempting to load GGUF: {path}")
            # lazy import so we catch import-time errors before rebuild attempt
            llama_cpp = importlib.import_module("llama_cpp")
            Llama = getattr(llama_cpp, "Llama")
            # minimal params; adjust n_ctx or gpu settings if available
            lm = Llama(model_path=str(path), n_ctx=2048, n_gpu_layers=0, verbose=False)
            print("Model loaded successfully.")
            return lm
        except Exception as e:
            # Swallow and remember: the next candidate (or rebuild) may succeed.
            print(f"Loading {path.name} failed: {e}")
            last_exc = e
    # If both failed, attempt a rebuild then retry first candidate once
    try:
        print("Both GGUF variants failed to load. Rebuilding llama-cpp-python from source...")
        rebuild_llama_cpp()
    except Exception as e:
        print(f"Rebuild failed: {e}")
        # Prefer the original load error; fall back to the rebuild error.
        raise last_exc or e
    # After rebuild, import & load primary model
    try:
        import importlib
        # reload() picks up the freshly compiled extension module.
        llama_cpp = importlib.reload(importlib.import_module("llama_cpp"))
        Llama = getattr(llama_cpp, "Llama")
        path = candidates[0][1]
        if not path.is_file():
            download_file(candidates[0][0], path, expected_sha256=candidates[0][2])
        lm = Llama(model_path=str(path), n_ctx=2048, n_gpu_layers=0, verbose=False)
        print("Model loaded successfully after rebuild.")
        return lm
    except Exception as e:
        print(f"Load after rebuild failed: {e}")
        raise e
# ----------------------------------------------------------------------
# Processor and model wrapper
# ----------------------------------------------------------------------
# We keep AutoProcessor to reuse the chat template behaviour you used previously.
# Loaded eagerly at module import; downloads from the Hub on first run.
processor = AutoProcessor.from_pretrained(
    HF_PROCESSOR_NAME,
    trust_remote_code=True,
    num_additional_image_tokens=1,
    # Only pass an auth token when one is configured (gated/private repos).
    **({} if not HF_TOKEN else {"token": HF_TOKEN}),
)
# Lazily constructed model holder — the GGUF file is only fetched/loaded
# the first time generate() is called.
class ModelWrapper:
    """Wraps a llama-cpp Llama instance with lazy loading and text generation."""

    def __init__(self):
        # llama-cpp Llama instance; populated on demand by ensure_model().
        self.llm = None

    def ensure_model(self):
        """Load the GGUF model via try_load_gguf() unless already loaded."""
        if self.llm is not None:
            return
        self.llm = try_load_gguf()

    def generate(self, prompt: str, max_new_tokens: int = MAX_NEW_TOKENS):
        """Run a completion on *prompt* and return the generated text."""
        self.ensure_model()
        # llama-cpp-python call style: model(prompt, max_tokens=..., temperature=..., top_p=..., stop=...)
        result = self.llm(
            prompt,
            max_tokens=max_new_tokens,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            stop=STOP_STRS,
        )
        # Responses arrive as {"choices": [{"text": ...}, ...]}.
        choices = result.get("choices", [{}])
        return choices[0].get("text", "")

MODEL = ModelWrapper()
# ----------------------------------------------------------------------
# Inference: convert URL->image, build prompt via processor chat template, run llama-cpp
# ----------------------------------------------------------------------
def generate_caption_from_url(url: str, prompt: str = "Describe the image.") -> str:
    """Fetch the media at *url*, extract a frame, and caption it with the model.

    Returns either the generated caption or a human-readable error string;
    this function never raises — all failures are reported as text for the UI.
    """
    if not url:
        return "No URL provided."
    try:
        raw = download_bytes(url)
    except Exception as e:
        return f"Download error: {e}"
    # Strip any query string before checking the file extension.
    lower = url.lower().split("?")[0]
    try:
        # Detect MP4 by extension or by the ISO-BMFF 'ftyp' marker in the first bytes.
        if lower.endswith(".mp4") or raw[:16].lower().find(b"ftyp") != -1:
            try:
                raw = mp4_to_gif(raw)
            except Exception as e:
                return f"MP4→GIF conversion failed: {e}"
        img = load_first_frame_from_bytes(raw)
    except Exception as e:
        return f"Image processing error: {e}"
    # Resize to a conservative size (512) expected by many VLMs
    try:
        img = img.resize((512, 512), resample=Image.BICUBIC)
    except Exception:
        pass
    try:
        # Produce conversation so the processor inserts image token correctly
        conversation = [
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}
        ]
        inputs = processor.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
            images=img,
        )
        # The processor provides a textual input (input_ids). We'll decode it to a plain prompt
        # string to feed llama-cpp. The processor has a `decode` helper; else we build a simple prompt.
        # Use processor.tokenizer if available to decode input_ids -> text.
        text_prompt = None
        # NOTE(review): getattr() on `inputs` assumes a BatchFeature-style object
        # with attribute access; on a plain dict this is always None — confirm.
        if hasattr(processor, "tokenizer") and getattr(inputs, "input_ids", None) is not None:
            try:
                # inputs may be dict tensors; extract CPU numpy/torch then decode
                input_ids = inputs["input_ids"][0]
                # convert to list of ints if tensor
                import torch
                if hasattr(input_ids, "cpu"):
                    ids = input_ids.cpu().numpy().tolist()
                else:
                    ids = list(input_ids)
                text_prompt = processor.tokenizer.decode(ids, skip_special_tokens=True)
            except Exception:
                text_prompt = None
        if not text_prompt:
            # Fallback: simple textual template with a tag where the image is referenced.
            text_prompt = f"<img> [image here] </img>\n{prompt}\nAnswer:"
        # Debug prints (Space logs)
        print("Prompt to model (truncated):", text_prompt[:512].replace("\n", "\\n"))
        out_text = MODEL.generate(text_prompt, max_new_tokens=MAX_NEW_TOKENS)
        # Postprocess: strip, remove accidental stop tokens, etc.
        return out_text.strip()
    except Exception as e:
        return f"Inference error: {e}"
# ----------------------------------------------------------------------
# Gradio UI (URL + prompt -> text)
# ----------------------------------------------------------------------
gradio_kwargs = dict(
    fn=generate_caption_from_url,
    inputs=[
        gr.Textbox(label="Image / GIF / MP4 URL", placeholder="https://example.com/photo.jpg"),
        gr.Textbox(label="Prompt (optional)", value="Describe the image."),
    ],
    outputs=gr.Textbox(label="Generated caption"),
    title="JoyCaption - URL input (GGUF + auto-rebuild)",
    description="Paste a direct link to an image/GIF/MP4 (MP4 will be converted).",
)
# allow_flagging was removed in newer Gradio releases; retry without it
# when this Gradio version rejects the keyword.
try:
    iface = gr.Interface(**gradio_kwargs, allow_flagging="never")
except TypeError:
    iface = gr.Interface(**gradio_kwargs)
if __name__ == "__main__":
    try:
        # 0.0.0.0:7860 is the conventional bind address/port for HF Spaces.
        iface.launch(server_name="0.0.0.0", server_port=7860)
    finally:
        # Best-effort asyncio loop cleanup on shutdown; failures are ignored.
        try:
            import asyncio
            loop = asyncio.get_event_loop()
            if not loop.is_closed():
                loop.close()
        except Exception:
            pass