Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| import os | |
| import io | |
| import re | |
| import sys | |
| import time | |
| import hashlib | |
| import pathlib | |
| import subprocess | |
| from typing import Optional | |
| import requests | |
| from PIL import Image, ImageSequence | |
| import gradio as gr | |
| # If you still want to use HF AutoProcessor / LlavaForConditionalGeneration for decoding, | |
| # keep transformers installed and uncomment the imports below. This file instead uses | |
| # llama-cpp-python for model inference (GGUF). | |
| from transformers import AutoProcessor | |
# ----------------------------------------------------------------------
# Config: set model URLs and optional checksums
# ----------------------------------------------------------------------
MODEL_DIR = pathlib.Path("model")  # local cache directory for GGUF files
MODEL_DIR.mkdir(parents=True, exist_ok=True)
# Replace these with your preferred GGUF files (mradermacher or TheBloke variants)
Q4_K_M_URL = (
    "https://huggingface.co/mradermacher/joycaption-llama/resolve/main/llama-joycaption-q4_k_m.gguf"
)
Q4_K_S_URL = (
    "https://huggingface.co/mradermacher/joycaption-llama/resolve/main/llama-joycaption-q4_k_s.gguf"
)
# Optional: set SHA256 checksums to validate downloads (replace with real values)
# NOTE(review): while these stay None, download_file() deletes and re-downloads
# any cached model file on every cold start, because an unverifiable cache is
# never trusted.
Q4_K_M_SHA256: Optional[str] = None
Q4_K_S_SHA256: Optional[str] = None
# Generation params
MAX_NEW_TOKENS = 128  # max tokens generated per caption
TEMPERATURE = 0.2     # low temperature -> mostly deterministic output
TOP_P = 0.95          # nucleus sampling cutoff
STOP_STRS = ["\n"]    # stop generation at the first newline
# HF processor/model name used previously for tokenization/chat template
HF_PROCESSOR_NAME = "fancyfeast/llama-joycaption-beta-one-hf-llava"
HF_TOKEN = os.getenv("HF_TOKEN")  # optional
# ----------------------------------------------------------------------
# Utilities: downloads, checksum, mp4->gif, image load
# ----------------------------------------------------------------------
def download_bytes(url: str, timeout: int = 30) -> bytes:
    """Fetch *url* and return the full response body as bytes.

    Raises requests.HTTPError on a non-2xx status code.
    """
    response = requests.get(url, stream=True, timeout=timeout)
    try:
        response.raise_for_status()
        return response.content
    finally:
        response.close()
def mp4_to_gif(mp4_bytes: bytes) -> bytes:
    """Convert MP4 bytes to GIF bytes via the ezgif.com web service.

    Uploads the video, scrapes the resulting GIF URL out of the returned
    HTML, normalizes protocol-relative / site-relative URLs to absolute
    ones, and downloads the GIF.

    Raises RuntimeError when no GIF URL can be found in the response and
    requests.HTTPError on failed HTTP calls.
    """
    upload = requests.post(
        "https://s.ezgif.com/video-to-gif",
        files={"new-file": ("video.mp4", mp4_bytes, "video/mp4")},
        data={"file": "video.mp4"},
        timeout=120,
    )
    upload.raise_for_status()
    html = upload.text
    # Primary pattern: an <img> tag pointing at a .gif; fall back to any
    # src attribute under a /tmp/ path.
    found = re.search(r'<img[^>]+src="([^"]+\.gif)"', html)
    if found is None:
        found = re.search(r'src="([^"]+?/tmp/[^"]+\.gif)"', html)
    if found is None:
        raise RuntimeError("Failed to extract GIF URL from ezgif response")
    gif_url = found.group(1)
    if gif_url.startswith("//"):
        gif_url = "https:" + gif_url
    elif gif_url.startswith("/"):
        gif_url = "https://s.ezgif.com" + gif_url
    with requests.get(gif_url, timeout=60) as gif_resp:
        gif_resp.raise_for_status()
        return gif_resp.content
def load_first_frame_from_bytes(raw: bytes) -> Image.Image:
    """Decode *raw* image bytes and return the first frame as an RGB PIL image."""
    frame = Image.open(io.BytesIO(raw))
    # Animated formats (GIF etc.): keep only the first frame.
    if getattr(frame, "is_animated", False):
        frame = next(ImageSequence.Iterator(frame))
    return frame if frame.mode == "RGB" else frame.convert("RGB")
def sha256_of_file(path: pathlib.Path) -> str:
    """Return the hex SHA-256 digest of the file at *path*, read in 64 KiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        while True:
            chunk = fh.read(65536)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()
def download_file(url: str, dest: pathlib.Path, expected_sha256: Optional[str] = None) -> None:
    """Download *url* to *dest*, optionally verifying its SHA-256 checksum.

    If *dest* already exists and matches *expected_sha256*, the download is
    skipped; otherwise the stale file is removed and re-downloaded. The
    payload is streamed into a temporary ".part" file and atomically renamed
    into place, so an interrupted download can never leave a truncated file
    at *dest* (the original wrote straight to *dest*, which could leave a
    corrupt model after a crash).

    Raises:
        ValueError: on a checksum mismatch.
        requests.HTTPError: on a failed HTTP request.
    """
    if dest.is_file():
        if expected_sha256:
            try:
                if sha256_of_file(dest) == expected_sha256:
                    return  # cached copy verified; nothing to do
            except Exception:
                pass  # unreadable file: fall through and re-download
        # remove possibly corrupted/old file
        dest.unlink()
    print(f"Downloading model from {url} -> {dest}")
    tmp = dest.with_suffix(dest.suffix + ".part")
    with requests.get(url, stream=True, timeout=120) as r:
        r.raise_for_status()
        total = int(r.headers.get("content-length", 0) or 0)
        downloaded = 0
        with open(tmp, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                if not chunk:
                    continue
                f.write(chunk)
                downloaded += len(chunk)
                if total:
                    # crude progress indicator for the Space build logs
                    pct = downloaded * 100 // total
                    print(f"\r{dest.name}: {pct}% ", end="", flush=True)
    print()
    if expected_sha256:
        got = sha256_of_file(tmp)
        if got != expected_sha256:
            tmp.unlink(missing_ok=True)  # drop the bad partial download
            raise ValueError(f"Checksum mismatch for {dest}: got {got}, expected {expected_sha256}")
    tmp.replace(dest)  # atomic rename into the final location
# ----------------------------------------------------------------------
# llama-cpp loading + automated rebuild
# ----------------------------------------------------------------------
def rebuild_llama_cpp() -> None:
    """Reinstall llama-cpp-python from source (no binary wheel).

    Sets PIP_NO_BINARY so pip compiles the package locally, upgrades the
    build toolchain first, then force-reinstalls llama-cpp-python.
    ``--force-reinstall`` + ``--no-cache-dir`` are required: a plain
    ``--upgrade`` is a no-op when the latest version is already installed,
    and a cached wheel would defeat the source build.

    Raises subprocess.CalledProcessError if any pip invocation fails.
    """
    env = os.environ.copy()
    env["PIP_NO_BINARY"] = "llama-cpp-python"
    pip_install = [sys.executable, "-m", "pip", "install", "--upgrade"]
    # upgrade pip then reinstall
    subprocess.check_call(pip_install + ["pip"], env=env)
    subprocess.check_call(pip_install + ["cmake", "wheel", "setuptools"], env=env)
    subprocess.check_call(
        pip_install + ["--force-reinstall", "--no-cache-dir", "llama-cpp-python"],
        env=env,
    )
def try_load_gguf() -> "llama_cpp.Llama":
    """
    Download Q4_K_M then Q4_K_S and attempt to load with llama_cpp.Llama.
    If both fail, rebuild llama-cpp-python from source and retry once.

    Returns a ready llama_cpp.Llama instance. If the rebuild or the
    post-rebuild load also fails, the last load error is re-raised.
    """
    import importlib
    from pathlib import Path  # NOTE(review): unused import -- candidate for removal
    # (url, destination path, optional sha256) tried in order of preference.
    candidates = [
        (Q4_K_M_URL, MODEL_DIR / "llama-joycaption-q4_k_m.gguf", Q4_K_M_SHA256),
        (Q4_K_S_URL, MODEL_DIR / "llama-joycaption-q4_k_s.gguf", Q4_K_S_SHA256),
    ]
    last_exc = None
    for url, path, sha in candidates:
        try:
            download_file(url, path, expected_sha256=sha)
            print(f"Attempting to load GGUF: {path}")
            # lazy import so we catch import-time errors before rebuild attempt
            llama_cpp = importlib.import_module("llama_cpp")
            Llama = getattr(llama_cpp, "Llama")
            # minimal params; adjust n_ctx or gpu settings if available
            lm = Llama(model_path=str(path), n_ctx=2048, n_gpu_layers=0, verbose=False)
            print("Model loaded successfully.")
            return lm
        except Exception as e:
            # Remember the failure and fall through to the next candidate.
            print(f"Loading {path.name} failed: {e}")
            last_exc = e
    # If both failed, attempt a rebuild then retry first candidate once
    try:
        print("Both GGUF variants failed to load. Rebuilding llama-cpp-python from source...")
        rebuild_llama_cpp()
    except Exception as e:
        print(f"Rebuild failed: {e}")
        # Prefer surfacing the original load error over the rebuild error.
        raise last_exc or e
    # After rebuild, import & load primary model
    try:
        import importlib
        # reload() picks up the freshly rebuilt native extension
        llama_cpp = importlib.reload(importlib.import_module("llama_cpp"))
        Llama = getattr(llama_cpp, "Llama")
        path = candidates[0][1]
        if not path.is_file():
            download_file(candidates[0][0], path, expected_sha256=candidates[0][2])
        lm = Llama(model_path=str(path), n_ctx=2048, n_gpu_layers=0, verbose=False)
        print("Model loaded successfully after rebuild.")
        return lm
    except Exception as e:
        print(f"Load after rebuild failed: {e}")
        raise e
# ----------------------------------------------------------------------
# Processor and model wrapper
# ----------------------------------------------------------------------
# We keep AutoProcessor to reuse the chat template behaviour you used previously.
# NOTE(review): this runs at import time and hits the network/HF cache --
# a cold start without connectivity will fail here.
processor = AutoProcessor.from_pretrained(
    HF_PROCESSOR_NAME,
    trust_remote_code=True,
    num_additional_image_tokens=1,
    # Pass the HF auth token only when the HF_TOKEN env var is set.
    **({} if not HF_TOKEN else {"token": HF_TOKEN}),
)
# Lazy model holder
class ModelWrapper:
    """Lazily instantiates the llama-cpp model on first use."""

    def __init__(self):
        self.llm = None  # llama-cpp Llama instance, created on demand

    def ensure_model(self):
        """Load the GGUF model via try_load_gguf() if not yet loaded."""
        if self.llm is None:
            self.llm = try_load_gguf()

    def generate(self, prompt: str, max_new_tokens: int = MAX_NEW_TOKENS):
        """Run a completion on *prompt* and return the generated text."""
        self.ensure_model()
        # llama-cpp-python call style: model(prompt=..., max_tokens=..., temperature=..., top_p=..., stop=...)
        completion = self.llm(
            prompt,
            max_tokens=max_new_tokens,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            stop=STOP_STRS,
        )
        # llama-cpp-python responses usually in out["choices"][0]["text"]
        choices = completion.get("choices", [{}])
        return choices[0].get("text", "")
MODEL = ModelWrapper()
# ----------------------------------------------------------------------
# Inference: convert URL->image, build prompt via processor chat template, run llama-cpp
# ----------------------------------------------------------------------
def generate_caption_from_url(url: str, prompt: str = "Describe the image.") -> str:
    """Download media from *url*, build a chat prompt, and caption it.

    Accepts direct links to images, GIFs, or MP4s (MP4s are converted to
    GIF via ezgif first). Returns either the generated caption or a
    human-readable error string; this function never raises, making it
    safe to use directly as a Gradio handler.

    Fix vs. original: the decode branch imported torch unconditionally, so
    on machines without torch the ImportError silently forced the crude
    fallback prompt. The code only duck-types ``hasattr(input_ids, "cpu")``,
    so the import was dead weight and has been removed.
    """
    if not url:
        return "No URL provided."
    try:
        raw = download_bytes(url)
    except Exception as e:
        return f"Download error: {e}"
    # Strip any query string before sniffing the file extension.
    lower = url.lower().split("?")[0]
    try:
        # "ftyp" near the start of the payload marks an MP4 container even
        # when the URL carries no .mp4 extension.
        if lower.endswith(".mp4") or raw[:16].lower().find(b"ftyp") != -1:
            try:
                raw = mp4_to_gif(raw)
            except Exception as e:
                return f"MP4→GIF conversion failed: {e}"
        img = load_first_frame_from_bytes(raw)
    except Exception as e:
        return f"Image processing error: {e}"
    # Resize to a conservative size (512) expected by many VLMs
    try:
        img = img.resize((512, 512), resample=Image.BICUBIC)
    except Exception:
        pass  # best-effort: caption the original size if resize fails
    try:
        # Produce conversation so the processor inserts image token correctly
        conversation = [
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}
        ]
        inputs = processor.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
            images=img,
        )
        # Decode the processor's token ids back into a plain prompt string
        # for llama-cpp; fall back to a simple template if decoding fails.
        text_prompt = None
        if hasattr(processor, "tokenizer") and getattr(inputs, "input_ids", None) is not None:
            try:
                input_ids = inputs["input_ids"][0]
                # Torch tensors expose .cpu(); duck-typing here keeps this
                # path working even when torch itself is not importable.
                if hasattr(input_ids, "cpu"):
                    ids = input_ids.cpu().numpy().tolist()
                else:
                    ids = list(input_ids)
                text_prompt = processor.tokenizer.decode(ids, skip_special_tokens=True)
            except Exception:
                text_prompt = None
        if not text_prompt:
            # Fallback: simple textual template with a tag where the image is referenced.
            text_prompt = f"<img> [image here] </img>\n{prompt}\nAnswer:"
        # Debug prints (Space logs)
        print("Prompt to model (truncated):", text_prompt[:512].replace("\n", "\\n"))
        out_text = MODEL.generate(text_prompt, max_new_tokens=MAX_NEW_TOKENS)
        # Postprocess: strip, remove accidental stop tokens, etc.
        return out_text.strip()
    except Exception as e:
        return f"Inference error: {e}"
# ----------------------------------------------------------------------
# Gradio UI (URL + prompt -> text)
# ----------------------------------------------------------------------
gradio_kwargs = dict(
    fn=generate_caption_from_url,
    inputs=[
        gr.Textbox(label="Image / GIF / MP4 URL", placeholder="https://example.com/photo.jpg"),
        gr.Textbox(label="Prompt (optional)", value="Describe the image."),
    ],
    outputs=gr.Textbox(label="Generated caption"),
    title="JoyCaption - URL input (GGUF + auto-rebuild)",
    description="Paste a direct link to an image/GIF/MP4 (MP4 will be converted).",
)
# The TypeError fallback presumably guards against Gradio versions that no
# longer accept the allow_flagging kwarg -- confirm against the pinned
# gradio version.
try:
    iface = gr.Interface(**gradio_kwargs, allow_flagging="never")
except TypeError:
    iface = gr.Interface(**gradio_kwargs)
if __name__ == "__main__":
    try:
        # Bind to all interfaces on the standard HF Spaces port.
        iface.launch(server_name="0.0.0.0", server_port=7860)
    finally:
        # Best-effort cleanup of the asyncio loop after Gradio shuts down.
        # NOTE(review): asyncio.get_event_loop() is deprecated when no loop
        # is running on newer Pythons -- confirm this cleanup is still needed.
        try:
            import asyncio
            loop = asyncio.get_event_loop()
            if not loop.is_closed():
                loop.close()
        except Exception:
            pass