Hug0endob committed on
Commit
49d3ba7
·
verified ·
1 Parent(s): 10d27da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -100
app.py CHANGED
@@ -1,44 +1,28 @@
1
  import os
2
  import io
3
- import re
4
- import torch
5
  import requests
6
  from PIL import Image, ImageSequence
7
- from transformers import AutoProcessor, LlavaForConditionalGeneration
8
  import gradio as gr
9
 
10
- MODEL_NAME = "fancyfeast/llama-joycaption-beta-one-hf-llava"
11
- HF_TOKEN = os.getenv("HF_TOKEN") # optional
 
 
 
 
 
 
 
12
 
 
13
def download_bytes(url: str, timeout: int = 30) -> bytes:
    """Fetch *url* with a streaming GET and return the whole body as bytes.

    Raises requests.HTTPError for non-success status codes and the usual
    requests exceptions on network failures/timeouts.
    """
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        payload = response.content
    return payload
17
 
18
def mp4_to_gif(mp4_bytes: bytes) -> bytes:
    """Convert MP4 bytes to GIF bytes via the ezgif.com web service.

    Uploads the video, scrapes the response page for the produced GIF's
    URL, downloads that GIF and returns its raw bytes.

    Raises:
        RuntimeError: when no GIF URL can be located in the response.
        requests.HTTPError: on non-success HTTP status codes.
    """
    upload = requests.post(
        "https://s.ezgif.com/video-to-gif",
        files={"new-file": ("video.mp4", mp4_bytes, "video/mp4")},
        data={"file": "video.mp4"},
        timeout=120,
    )
    upload.raise_for_status()

    # Primary pattern: any <img> tag pointing at a .gif; fallback: any src
    # that references a gif under a /tmp/ path.
    found = re.search(r'<img[^>]+src="([^"]+\.gif)"', upload.text)
    if not found:
        found = re.search(r'src="([^"]+?/tmp/[^"]+\.gif)"', upload.text)
    if not found:
        raise RuntimeError("Failed to extract GIF URL from ezgif response")

    gif_url = found.group(1)
    # Normalise scheme-relative ("//host/...") and site-relative ("/...") URLs.
    if gif_url.startswith("//"):
        gif_url = "https:" + gif_url
    elif gif_url.startswith("/"):
        gif_url = "https://s.ezgif.com" + gif_url

    with requests.get(gif_url, timeout=60) as gif_resp:
        gif_resp.raise_for_status()
        return gif_resp.content
40
-
41
- def load_first_frame_from_bytes(raw: bytes) -> Image.Image:
42
  img = Image.open(io.BytesIO(raw))
43
  if getattr(img, "is_animated", False):
44
  img = next(ImageSequence.Iterator(img))
@@ -46,102 +30,69 @@ def load_first_frame_from_bytes(raw: bytes) -> Image.Image:
46
  img = img.convert("RGB")
47
  return img
48
 
49
# Load processor + model
# NOTE(review): token_arg is built with the deprecated `use_auth_token` key but
# is never used below — the from_pretrained calls pass {"token": HF_TOKEN}
# directly instead; consider deleting this line.
token_arg = {"use_auth_token": HF_TOKEN} if HF_TOKEN else {}
processor = AutoProcessor.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    # presumably required by the JoyCaption llava processor — TODO confirm
    num_additional_image_tokens=1,
    **({} if not HF_TOKEN else {"token": HF_TOKEN})
)
# CPU Space -> use float32
llava_model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    device_map="cpu",
    torch_dtype=torch.float32,
    trust_remote_code=True,
    **({} if not HF_TOKEN else {"token": HF_TOKEN})
)
# Inference mode: disables dropout / training-only behavior.
llava_model.eval()
66
 
67
def generate_caption_from_url(url: str, prompt: str = "Describe the image.") -> str:
    """Download media from *url*, extract a still frame and caption it with LLaVA.

    Returns either the generated caption or a human-readable error string;
    this function never raises — all failures are reported as text so the
    Gradio UI can display them.
    """
    if not url:
        return "No URL provided."
    try:
        raw = download_bytes(url)
    except Exception as e:
        return f"Download error: {e}"

    # Strip any query string before the extension check.
    lower = url.lower().split("?")[0]
    try:
        # Detect MP4 by extension or by an 'ftyp' box near the file start
        # (the box type normally sits at byte offset 4, within raw[:16]).
        if lower.endswith(".mp4") or raw[:16].lower().find(b"ftyp") != -1:
            try:
                raw = mp4_to_gif(raw)
            except Exception as e:
                return f"MP4→GIF conversion failed: {e}"
        img = load_first_frame_from_bytes(raw)
    except Exception as e:
        return f"Image processing error: {e}"

    # Resize to conservative default
    try:
        img = img.resize((512, 512), resample=Image.BICUBIC)
    except Exception:
        # Best effort: fall back to the original frame size on failure.
        pass

    try:
        conversation = [
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}
        ]
        inputs = processor.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
            images=img,
        )
        # Move every tensor in the batch onto the model's device.
        device = llava_model.device
        inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
        # pixel_values must also match the model dtype (float32 on CPU).
        if "pixel_values" in inputs:
            inputs["pixel_values"] = inputs["pixel_values"].to(dtype=llava_model.dtype, device=device)

        # Minimal debug info (appears in Space logs)
        if "pixel_values" in inputs:
            print("pixel_values.shape:", inputs["pixel_values"].shape)
        if "input_ids" in inputs:
            print("input_ids.shape:", inputs["input_ids"].shape)

        with torch.no_grad():
            out_ids = llava_model.generate(**inputs, max_new_tokens=128)
        # NOTE(review): decoding the full sequence also includes the prompt
        # text in the returned caption — consider slicing off the first
        # inputs["input_ids"].shape[1] tokens before decoding.
        caption = processor.decode(out_ids[0], skip_special_tokens=True)
        return caption
    except Exception as e:
        return f"Inference error: {e}"
 
 
 
 
 
120
 
121
# Interface kwargs are built once so the construction below can be retried
# without version-specific arguments.
gradio_kwargs = dict(
    fn=generate_caption_from_url,
    inputs=[
        gr.Textbox(label="Image / GIF / MP4 URL", placeholder="https://example.com/photo.jpg"),
        gr.Textbox(label="Prompt (optional)", value="Describe the image."),
    ],
    outputs=gr.Textbox(label="Generated caption"),
    title="JoyCaption - URL input",
    description="Paste a direct link to an image/GIF/MP4 (MP4 will be converted).",
)

# `allow_flagging` is not accepted by all Gradio versions; fall back to the
# plain constructor when it raises TypeError.
try:
    iface = gr.Interface(**gradio_kwargs, allow_flagging="never")
except TypeError:
    iface = gr.Interface(**gradio_kwargs)
136
-
137
if __name__ == "__main__":
    try:
        # Bind to all interfaces on the standard HF Spaces port.
        iface.launch(server_name="0.0.0.0", server_port=7860)
    finally:
        # Best-effort shutdown hack: close the asyncio event loop so the
        # process exits instead of hanging; swallow any failure.
        try:
            import asyncio
            loop = asyncio.get_event_loop()
            if not loop.is_closed():
                loop.close()
        except Exception:
            pass
 
1
  import os
2
  import io
3
+ import sys
4
+ import time
5
  import requests
6
  from PIL import Image, ImageSequence
 
7
  import gradio as gr
8
 
9
# Try to import llama-cpp-python
try:
    from llama_cpp import Llama
except Exception as e:
    # Fail fast with an actionable message — the Space is unusable without it.
    raise RuntimeError("llama-cpp-python import failed; ensure requirements installed and wheel built: " + str(e))

MODEL_PATH = os.path.join("model", "model.gguf")  # start.sh places GGUF here
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model not found at {MODEL_PATH}. Set correct GGUF in start.sh and redeploy.")
18
 
19
# Helper: fetch a URL and return its raw bytes.
def download_bytes(url: str, timeout: int = 30) -> bytes:
    """Return the full response body of *url* as bytes.

    Raises requests.HTTPError on non-2xx responses and the usual
    requests exceptions on network failure.
    """
    response = requests.get(url, stream=True, timeout=timeout)
    try:
        response.raise_for_status()
        return response.content
    finally:
        response.close()
24
 
25
+ def load_first_frame_from_bytes(raw: bytes):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  img = Image.open(io.BytesIO(raw))
27
  if getattr(img, "is_animated", False):
28
  img = next(ImageSequence.Iterator(img))
 
30
  img = img.convert("RGB")
31
  return img
32
 
33
+ # Minimal image caption prompt template — adjust for your model's expected prompt
34
+ def make_prompt_for_image(image_path: str, user_prompt: str = "Describe the image."):
35
+ # Many llama.cpp-based multimodal ggufs accept: "<img>{path}</img>\nUser: {prompt}\nAssistant:"
36
+ # We'll use that pattern.
37
+ return f"<img>{image_path}</img>\nUser: {user_prompt}\nAssistant:"
38
+
39
# Start model (llama-cpp-python will mmap model and run inference)
# Use low-memory opts: n_ctx small, use_mlock=0, n_gpu_layers=0
# NOTE(review): use_mlock / n_gpu_layers are not actually passed below —
# the line relies on llama-cpp-python defaults; confirm they match intent.
print("Loading model (this may take a minute)...", file=sys.stderr)
llm = Llama(model_path=MODEL_PATH, n_ctx=2048, n_threads=2)
 
 
 
 
 
 
 
43
 
44
def generate_caption_from_url(url: str, prompt: str = "Describe the image."):
    """Download an image/GIF from *url*, save a temp JPEG and caption it locally.

    Returns the generated caption, or a human-readable error string — this
    function never raises; all failures are reported as text so the Gradio
    UI can display them.
    """
    if not url:
        return "No URL provided."
    try:
        raw = download_bytes(url)
    except Exception as e:
        return f"Download error: {e}"

    try:
        img = load_first_frame_from_bytes(raw)
    except Exception as e:
        return f"Image processing error: {e}"

    # Save a temporary JPEG locally so the gguf image token loader can access it
    tmp_dir = "/tmp/joycap"
    os.makedirs(tmp_dir, exist_ok=True)
    # Millisecond timestamp is unique enough for a single-worker Space.
    ts = int(time.time() * 1000)
    tmp_path = os.path.join(tmp_dir, f"{ts}.jpg")
    try:
        img.save(tmp_path, format="JPEG", quality=85)
    except Exception as e:
        return f"Failed to save temp image: {e}"

    prompt_full = make_prompt_for_image(tmp_path, prompt)
    try:
        # BUGFIX: llama_cpp.Llama has no `.create` method — calling it raised
        # AttributeError on every request (surfacing as "Inference error").
        # The text-completion entry point is `create_completion` (also
        # reachable as `llm(...)`).
        resp = llm.create_completion(
            prompt=prompt_full,
            max_tokens=256,
            temperature=0.2,
            top_p=0.95,
            stop=["User:", "Assistant:"],
        )
        text = resp.get("choices", [{}])[0].get("text", "").strip()
        return text or "No caption generated."
    except Exception as e:
        return f"Inference error: {e}"
    finally:
        # Always remove the temp frame; ignore a missing file or races.
        try:
            os.remove(tmp_path)
        except Exception:
            pass
85
 
86
# Gradio UI: one URL textbox plus an optional prompt in, the caption (or an
# error string from generate_caption_from_url) out.
iface = gr.Interface(
    fn=generate_caption_from_url,
    inputs=[
        gr.Textbox(label="Image / GIF / MP4 URL", placeholder="https://example.com/photo.jpg"),
        gr.Textbox(label="Prompt (optional)", value="Describe the image."),
    ],
    outputs=gr.Textbox(label="Generated caption"),
    title="JoyCaption - local GGUF (Q4)",
    description="Runs a quantized GGUF model locally via llama.cpp (no external APIs). Ensure the GGUF in start.sh is a multimodal model that supports <img> tags.",
)
96
 
 
 
 
 
 
97
if __name__ == "__main__":
    # Bind to all interfaces on the standard HF Spaces port.
    iface.launch(server_name="0.0.0.0", server_port=7860)