import os
import io
import re
import torch
import requests
from PIL import Image, ImageSequence
from transformers import AutoProcessor, LlavaForConditionalGeneration
import gradio as gr
MODEL_NAME = "fancyfeast/llama-joycaption-beta-one-hf-llava"
HF_TOKEN = os.getenv("HF_TOKEN") # optional
def download_bytes(url: str, timeout: int = 30) -> bytes:
    """Fetch *url* over HTTP(S) and return the full response body as bytes.

    Raises requests.HTTPError on a non-2xx status and
    requests.Timeout after *timeout* seconds.
    """
    response = requests.get(url, stream=True, timeout=timeout)
    try:
        response.raise_for_status()
        return response.content
    finally:
        # Always release the connection back to the pool.
        response.close()
def mp4_to_gif(mp4_bytes: bytes) -> bytes:
    """Convert raw MP4 bytes to GIF bytes via the ezgif.com web service.

    Uploads the video, scrapes the result page HTML for the output GIF
    URL, then downloads and returns the GIF bytes.

    Raises RuntimeError when no GIF URL can be found in the response,
    and requests.HTTPError on any failed HTTP call.
    """
    upload = requests.post(
        "https://s.ezgif.com/video-to-gif",
        files={"new-file": ("video.mp4", mp4_bytes, "video/mp4")},
        data={"file": "video.mp4"},
        timeout=120,
    )
    upload.raise_for_status()

    # Primary pattern: any <img> tag whose src ends in .gif; fall back
    # to a looser match on ezgif's /tmp/ output path.
    found = re.search(r'<img[^>]+src="([^"]+\.gif)"', upload.text)
    if found is None:
        found = re.search(r'src="([^"]+?/tmp/[^"]+\.gif)"', upload.text)
    if found is None:
        raise RuntimeError("Failed to extract GIF URL from ezgif response")

    gif_url = found.group(1)
    # Normalize protocol-relative ("//...") and site-relative ("/...")
    # URLs to absolute https URLs.
    if gif_url.startswith("//"):
        gif_url = "https:" + gif_url
    elif gif_url.startswith("/"):
        gif_url = "https://s.ezgif.com" + gif_url

    with requests.get(gif_url, timeout=60) as gif_resp:
        gif_resp.raise_for_status()
        return gif_resp.content
def load_first_frame_from_bytes(raw: bytes) -> Image.Image:
    """Decode *raw* image bytes and return the first frame as an RGB image.

    For animated formats (e.g. GIF) only the first frame is kept; any
    non-RGB mode (RGBA, P, L, ...) is converted to RGB.
    """
    image = Image.open(io.BytesIO(raw))
    if getattr(image, "is_animated", False):
        # Animated file: keep just the first frame.
        image = next(ImageSequence.Iterator(image))
    return image if image.mode == "RGB" else image.convert("RGB")
# Load processor + model.
# BUGFIX: the original built ``token_arg`` with the deprecated
# ``use_auth_token`` kwarg and then never used it, duplicating an inline
# token dict in both from_pretrained calls instead. Build it once with
# the modern ``token`` kwarg and reuse it; an empty dict means anonymous
# access to the public model.
token_arg = {"token": HF_TOKEN} if HF_TOKEN else {}

processor = AutoProcessor.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    num_additional_image_tokens=1,
    **token_arg,
)

# CPU Space -> use float32 (half precision is slow/unsupported on CPU).
llava_model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    device_map="cpu",
    torch_dtype=torch.float32,
    trust_remote_code=True,
    **token_arg,
)
llava_model.eval()
def generate_caption_from_url(url: str, prompt: str = "Describe the image.") -> str:
    """Download the media at *url*, extract a frame, and caption it.

    Supports direct image/GIF links and MP4 (converted to GIF via ezgif
    first). Returns the generated caption, or a human-readable error
    string on failure — this function never raises, so it is safe as a
    Gradio callback.
    """
    if not url:
        return "No URL provided."
    try:
        raw = download_bytes(url)
    except Exception as e:
        return f"Download error: {e}"

    # Detect MP4 by URL extension (query string stripped) or by the
    # 'ftyp' box marker near the start of the file.
    lower = url.lower().split("?")[0]
    try:
        if lower.endswith(".mp4") or b"ftyp" in raw[:16].lower():
            try:
                raw = mp4_to_gif(raw)
            except Exception as e:
                return f"MP4→GIF conversion failed: {e}"
        img = load_first_frame_from_bytes(raw)
    except Exception as e:
        return f"Image processing error: {e}"

    # Resize to a conservative default. NOTE(review): a fixed 512x512
    # ignores aspect ratio — confirm the model expects square inputs.
    try:
        img = img.resize((512, 512), resample=Image.BICUBIC)
    except Exception:
        pass  # best-effort: fall back to the original size

    try:
        conversation = [
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}
        ]
        inputs = processor.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
            images=img,
        )
        device = llava_model.device
        inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
        if "pixel_values" in inputs:
            inputs["pixel_values"] = inputs["pixel_values"].to(dtype=llava_model.dtype, device=device)
        # Minimal debug info (appears in Space logs)
        if "pixel_values" in inputs:
            print("pixel_values.shape:", inputs["pixel_values"].shape)
        if "input_ids" in inputs:
            print("input_ids.shape:", inputs["input_ids"].shape)
        with torch.no_grad():
            out_ids = llava_model.generate(**inputs, max_new_tokens=128)
        # BUGFIX: decode only the newly generated tokens. Decoding the
        # full sequence (out_ids[0]) would echo the chat template and
        # prompt text back into the returned caption.
        prompt_len = inputs["input_ids"].shape[1] if "input_ids" in inputs else 0
        caption = processor.decode(out_ids[0][prompt_len:], skip_special_tokens=True)
        return caption.strip()
    except Exception as e:
        return f"Inference error: {e}"
# Build the Gradio UI. Shared kwargs are kept in a plain dict so the
# Interface can be constructed twice: newer Gradio releases removed the
# ``allow_flagging`` kwarg, so try it first and fall back without it.
gradio_kwargs = {
    "fn": generate_caption_from_url,
    "inputs": [
        gr.Textbox(label="Image / GIF / MP4 URL", placeholder="https://example.com/photo.jpg"),
        gr.Textbox(label="Prompt (optional)", value="Describe the image."),
    ],
    "outputs": gr.Textbox(label="Generated caption"),
    "title": "JoyCaption - URL input",
    "description": "Paste a direct link to an image/GIF/MP4 (MP4 will be converted).",
}
try:
    iface = gr.Interface(**gradio_kwargs, allow_flagging="never")
except TypeError:
    iface = gr.Interface(**gradio_kwargs)
if __name__ == "__main__":
    try:
        # 0.0.0.0:7860 is the standard bind address/port for HF Spaces.
        iface.launch(server_name="0.0.0.0", server_port=7860)
    finally:
        # Best-effort cleanup: close any lingering asyncio event loop so
        # the process exits cleanly after the Gradio server stops.
        # NOTE(review): asyncio.get_event_loop() is deprecated outside a
        # running loop on recent Python; any failure here is ignored.
        try:
            import asyncio

            event_loop = asyncio.get_event_loop()
            if not event_loop.is_closed():
                event_loop.close()
        except Exception:
            pass
|