Hug0endob committed on
Commit
7766a5c
·
verified ·
1 Parent(s): f275d7c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -86
app.py CHANGED
@@ -1,135 +1,103 @@
1
- import os, torch
2
- from transformers import AutoProcessor, LlavaForConditionalGeneration
3
- import gradio as gr
4
- from PIL import Image, ImageSequence
5
  import requests
6
  from io import BytesIO
 
 
 
 
 
7
 
8
- # ---- 1️⃣ Use a public repo ----
9
- MODEL_NAME = "llava-hf/joycaption-llama3.1-8b" # public version
10
 
11
- processor = AutoProcessor.from_pretrained(MODEL_NAME)
 
 
12
  llava_model = LlavaForConditionalGeneration.from_pretrained(
13
  MODEL_NAME,
14
  device_map="cpu",
15
  torch_dtype=torch.bfloat16,
 
16
  )
17
  llava_model.eval()
18
 
19
- # -------------------------------------------------
20
- # Helper: download a file from a URL
21
- # -------------------------------------------------
22
- def download_bytes(url: str) -> bytes:
23
- resp = requests.get(url, stream=True, timeout=30)
24
  resp.raise_for_status()
25
  return resp.content
26
 
27
- # -------------------------------------------------
28
- # Helper: convert MP4 → GIF using ezgif.com (public API)
29
- # -------------------------------------------------
30
  def mp4_to_gif(mp4_bytes: bytes) -> bytes:
31
- """
32
- Sends the MP4 bytes to ezgif.com and returns the resulting GIF bytes.
33
- The API is undocumented but works via a simple multipart POST.
34
- """
35
  files = {"new-file": ("video.mp4", mp4_bytes, "video/mp4")}
36
- # ezgif.com endpoint for MP4 → GIF conversion
37
  resp = requests.post(
38
  "https://s.ezgif.com/video-to-gif",
39
  files=files,
40
  data={"file": "video.mp4"},
41
- timeout=60,
42
  )
43
  resp.raise_for_status()
44
-
45
- # The response HTML contains a link to the generated GIF.
46
- # We extract the first <img src="..."> that ends with .gif
47
- import re
48
-
49
  match = re.search(r'<img[^>]+src="([^"]+\.gif)"', resp.text)
 
 
 
50
  if not match:
51
  raise RuntimeError("Failed to extract GIF URL from ezgif response")
52
  gif_url = match.group(1)
53
-
54
- # ezgif serves the GIF from a relative path; make it absolute
55
  if gif_url.startswith("//"):
56
  gif_url = "https:" + gif_url
57
  elif gif_url.startswith("/"):
58
  gif_url = "https://s.ezgif.com" + gif_url
59
-
60
- gif_resp = requests.get(gif_url, timeout=30)
61
  gif_resp.raise_for_status()
62
  return gif_resp.content
63
 
64
- # -------------------------------------------------
65
- # Main inference function
66
- # -------------------------------------------------
67
- def generate_caption_from_url(url: str, prompt: str = "Describe the image.") -> str:
68
- """
69
- 1. Download the resource.
70
- 2. If it is an MP4 → convert to GIF.
71
- 3. Load the first frame of the image/GIF.
72
- 4. Run JoyCaption and return the caption.
73
- """
74
- # -----------------------------------------------------------------
75
- # 1️⃣ Download raw bytes
76
- # -----------------------------------------------------------------
77
- raw = download_bytes(url)
78
-
79
- # -----------------------------------------------------------------
80
- # 2️⃣ Determine type & possibly convert MP4 → GIF
81
- # -----------------------------------------------------------------
82
- lower_url = url.lower()
83
- if lower_url.endswith(".mp4"):
84
- # Convert video to GIF
85
- raw = mp4_to_gif(raw)
86
- # After conversion we treat it as a GIF
87
- lower_url = ".gif"
88
-
89
- # -----------------------------------------------------------------
90
- # 3️⃣ Load image (first frame for GIFs)
91
- # -----------------------------------------------------------------
92
  img = Image.open(BytesIO(raw))
93
-
94
- # If the file is a multi‑frame GIF, pick the first frame
95
  if getattr(img, "is_animated", False):
96
  img = next(ImageSequence.Iterator(img))
97
-
98
- # Ensure RGB (JoyCaption expects 3‑channel images)
99
  if img.mode != "RGB":
100
  img = img.convert("RGB")
 
101
 
102
- # -----------------------------------------------------------------
103
- # 4️⃣ Run the model
104
- # -----------------------------------------------------------------
105
- inputs = processor(images=img, text=prompt, return_tensors="pt")
106
- inputs = {k: v.to(llava_model.device) for k, v in inputs.items()}
107
-
108
- with torch.no_grad():
109
- out_ids = llava_model.generate(**inputs, max_new_tokens=64)
110
-
111
- caption = processor.decode(out_ids[0], skip_special_tokens=True)
112
- return caption
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
- # -------------------------------------------------
115
- # Gradio UI
116
- # -------------------------------------------------
117
  iface = gr.Interface(
118
  fn=generate_caption_from_url,
119
  inputs=[
120
- gr.Textbox(
121
- label="Image / GIF / MP4 URL",
122
- placeholder="https://example.com/photo.jpg or https://example.com/clip.mp4",
123
- ),
124
  gr.Textbox(label="Prompt (optional)", value="Describe the image."),
125
  ],
126
  outputs=gr.Textbox(label="Generated caption"),
127
- title="JoyCaption URL input (supports GIF & MP4)",
128
- description=(
129
- "Enter a direct URL to an image, an animated GIF, or an MP4 video. "
130
- "MP4 files are automatically converted to GIF via ezgif.com, "
131
- "and the first frame of the GIF is captioned."
132
- ),
133
  allow_flagging="never",
134
  )
135
 
 
1
+ import os
2
+ import re
3
+ import torch
 
4
  import requests
5
  from io import BytesIO
6
+ from PIL import Image, ImageSequence
7
+ from transformers import AutoProcessor, LlavaForConditionalGeneration
8
+ import gradio as gr
9
+
10
# ---------------------------------------------------------------
# Model setup (runs once, at import time)
# ---------------------------------------------------------------
MODEL_NAME = "fancyfeast/llama-joycaption-beta-one-hf-llava"  # public repo

# An HF token is optional for a public repo; honour one if the Space sets it.
HF_TOKEN = os.getenv("HF_TOKEN")

# Forward the token to from_pretrained() only when one is actually present.
token_arg = dict(token=HF_TOKEN) if HF_TOKEN else dict()

processor = AutoProcessor.from_pretrained(MODEL_NAME, **token_arg)
llava_model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    device_map="cpu",            # CPU-only Space
    torch_dtype=torch.bfloat16,  # 2 bytes/param instead of 4 for float32
    **token_arg,
)
llava_model.eval()  # inference mode: disables dropout etc.
25
 
26
def download_bytes(url: str, timeout: int = 30) -> bytes:
    """Fetch *url* and return the raw response body as bytes.

    Args:
        url: Direct link to the resource.
        timeout: Seconds before the request aborts.

    Raises:
        requests.HTTPError: on a non-2xx status.
    """
    # The original used stream=True and then read .content, which buffers the
    # whole body anyway but can leave the pooled connection un-released.
    # A context manager guarantees the response is closed.
    with requests.get(url, timeout=timeout) as resp:
        resp.raise_for_status()
        return resp.content
 
 
 
 
31
def mp4_to_gif(mp4_bytes: bytes) -> bytes:
    """Convert MP4 bytes to GIF bytes via ezgif.com and return the GIF bytes.

    NOTE(review): this scrapes ezgif's undocumented endpoint and parses its
    HTML with regexes; the site's markup or upload flow can change at any
    time. Confirm the single-POST conversion still works (ezgif's normal flow
    is an upload step followed by a separate convert-form submission).

    Raises:
        requests.HTTPError: if either HTTP call returns a non-2xx status.
        RuntimeError: if no .gif URL can be found in the response HTML.
    """
    files = {"new-file": ("video.mp4", mp4_bytes, "video/mp4")}
    resp = requests.post(
        "https://s.ezgif.com/video-to-gif",
        files=files,
        data={"file": "video.mp4"},
        timeout=120,  # conversion of larger clips can be slow
    )
    resp.raise_for_status()
    # Primary extraction: first <img> tag whose src ends in .gif.
    match = re.search(r'<img[^>]+src="([^"]+\.gif)"', resp.text)
    if not match:
        # try to extract via other img tags
        match = re.search(r'src="([^"]+?/tmp/[^"]+\.gif)"', resp.text)
    if not match:
        raise RuntimeError("Failed to extract GIF URL from ezgif response")
    gif_url = match.group(1)
    # ezgif may return protocol-relative ("//...") or site-relative ("/...")
    # URLs; absolutize them before the final download.
    if gif_url.startswith("//"):
        gif_url = "https:" + gif_url
    elif gif_url.startswith("/"):
        gif_url = "https://s.ezgif.com" + gif_url
    gif_resp = requests.get(gif_url, timeout=60)
    gif_resp.raise_for_status()
    return gif_resp.content
54
 
55
def load_first_frame_from_bytes(raw: bytes) -> Image.Image:
    """Decode *raw* image bytes and return the first frame as an RGB image."""
    frame = Image.open(BytesIO(raw))
    # Animated formats (e.g. GIF): caption only the opening frame.
    if getattr(frame, "is_animated", False):
        frame = next(ImageSequence.Iterator(frame))
    # The model expects 3-channel input; convert palette/greyscale frames.
    return frame if frame.mode == "RGB" else frame.convert("RGB")
62
 
63
def generate_caption_from_url(url: str, prompt: str = "Describe the image.") -> str:
    """Download a media URL, reduce it to one RGB frame, and caption it.

    Args:
        url: Direct link to an image, GIF, or MP4.
        prompt: Instruction passed to the captioning model.

    Returns:
        The generated caption, or a human-readable error string — Gradio
        simply displays whatever string we return.
    """
    # Guard against None and whitespace-only input (the original only caught
    # empty/None, so "  " would trigger a doomed download).
    url = (url or "").strip()
    if not url:
        return "No URL provided."

    try:
        raw = download_bytes(url)
    except Exception as e:
        return f"Download error: {e}"

    # Strip query string AND fragment before sniffing the file extension
    # (the original left "#fragment" attached, defeating .endswith(".mp4")).
    lower = url.lower().split("?")[0].split("#")[0]
    try:
        # Detect MP4 by extension or by the ISO-BMFF 'ftyp' marker that sits
        # within the first bytes of MP4 files.
        if lower.endswith(".mp4") or b"ftyp" in raw[:16].lower():
            try:
                raw = mp4_to_gif(raw)
            except Exception as e:
                return f"MP4→GIF conversion failed: {e}"
        img = load_first_frame_from_bytes(raw)
    except Exception as e:
        return f"Image processing error: {e}"

    try:
        inputs = processor(images=img, text=prompt, return_tensors="pt")
        inputs = {k: v.to(llava_model.device) for k, v in inputs.items()}
        with torch.no_grad():
            out_ids = llava_model.generate(**inputs, max_new_tokens=128)
        return processor.decode(out_ids[0], skip_special_tokens=True)
    except Exception as e:
        return f"Inference error: {e}"
91
 
 
 
 
92
# ---------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------
# NOTE(review): there is no .launch() call here — confirm the Space's SDK
# auto-launches this top-level Interface.
url_box = gr.Textbox(label="Image / GIF / MP4 URL", placeholder="https://example.com/photo.jpg")
prompt_box = gr.Textbox(label="Prompt (optional)", value="Describe the image.")

iface = gr.Interface(
    fn=generate_caption_from_url,
    inputs=[url_box, prompt_box],
    outputs=gr.Textbox(label="Generated caption"),
    title="JoyCaption (public fancyfeast) - URL input",
    description="Paste a direct link to an image, GIF, or MP4. MP4 files are converted to GIF via ezgif.com; the first frame is captioned.",
    allow_flagging="never",
)
103