Hug0endob commited on
Commit
851e8b5
·
verified ·
1 Parent(s): b6a2d55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -25
app.py CHANGED
@@ -1,57 +1,139 @@
 
1
  import torch
 
 
 
2
  from transformers import AutoProcessor, LlavaForConditionalGeneration
3
  import gradio as gr
4
- from PIL import Image
5
- import requests
6
 
7
  # -------------------------------------------------
8
- # Model identifier (replace if you fork or use a different checkpoint)
9
  # -------------------------------------------------
10
- MODEL_NAME = "fpgaminer/joycaption-llama3.1-8b" # 8‑B checkpoint fits comfortably on CPU
11
 
12
- # -------------------------------------------------
13
- # Load processor and model (CPU only)
14
- # -------------------------------------------------
15
processor = AutoProcessor.from_pretrained(MODEL_NAME)

# Load the checkpoint entirely on the CPU (the Space targets the free
# CPU tier); bfloat16 is this model's native dtype.
llava_model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    device_map="cpu",
    torch_dtype=torch.bfloat16,
)
llava_model.eval()
25
 
26
  # -------------------------------------------------
27
- # Inference function used by Gradio
28
  # -------------------------------------------------
29
def generate_caption(image: Image.Image, prompt: str = "Describe the image.") -> str:
    """Run JoyCaption on *image*, steered by *prompt*, and return the text."""
    # Tokenize/preprocess, then move every tensor onto the model's device.
    model_inputs = processor(images=image, text=prompt, return_tensors="pt")
    model_inputs = {
        name: tensor.to(llava_model.device) for name, tensor in model_inputs.items()
    }

    # Up to 64 new tokens; raise this for longer captions.
    with torch.no_grad():
        generated_ids = llava_model.generate(**model_inputs, max_new_tokens=64)

    return processor.decode(generated_ids[0], skip_special_tokens=True)
41
 
42
# -------------------------------------------------
# Gradio UI
# -------------------------------------------------
image_input = gr.Image(type="pil", label="Upload an image")
prompt_input = gr.Textbox(label="Prompt (optional)", value="Describe the image.")
caption_output = gr.Textbox(label="Generated caption")

iface = gr.Interface(
    fn=generate_caption,
    inputs=[image_input, prompt_input],
    outputs=caption_output,
    title="JoyCaption (CPU‑only) Demo",
    description="Upload an image and let the JoyCaption model generate a caption. Runs entirely on the free CPU tier.",
    allow_flagging="never",
)
56
 
57
  if __name__ == "__main__":
 
1
+ import os
2
  import torch
3
+ import requests
4
+ from io import BytesIO
5
+ from PIL import Image, ImageSequence
6
  from transformers import AutoProcessor, LlavaForConditionalGeneration
7
  import gradio as gr
 
 
8
 
9
# -------------------------------------------------
# Model configuration (CPU‑only)
# -------------------------------------------------
MODEL_NAME = "fpgaminer/joycaption-llama3.1-8b"

processor = AutoProcessor.from_pretrained(MODEL_NAME)

llava_model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    device_map="cpu",            # keep every weight on the CPU
    torch_dtype=torch.bfloat16,  # checkpoint's native dtype
)
llava_model.eval()  # inference only
21
 
22
  # -------------------------------------------------
23
+ # Helper: download a file from a URL
24
  # -------------------------------------------------
25
def download_bytes(url: str) -> bytes:
    """Download *url* and return the raw response body.

    Raises:
        requests.HTTPError: on a non-2xx status code.
        requests.Timeout: if the server does not answer within 30 s.
    """
    # The original passed stream=True and then read ``resp.content``,
    # which buffers the entire body anyway while delaying connection
    # release; a plain GET is equivalent and cleaner.
    resp = requests.get(url, timeout=30)
    resp.raise_for_status()
    return resp.content
29
+
30
+ # -------------------------------------------------
31
+ # Helper: convert MP4 → GIF using ezgif.com (public API)
32
+ # -------------------------------------------------
33
def mp4_to_gif(mp4_bytes: bytes) -> bytes:
    """Convert MP4 bytes to GIF bytes via the public ezgif.com site.

    The endpoint is undocumented: we POST the video, scrape the first
    ``<img src="….gif">`` out of the HTML response, and download it.

    Raises:
        requests.HTTPError: if either HTTP request fails.
        RuntimeError: if no GIF link can be found in the response HTML.
    """
    import re  # stdlib; kept function-local to leave module imports untouched

    files = {"new-file": ("video.mp4", mp4_bytes, "video/mp4")}
    resp = requests.post(
        "https://s.ezgif.com/video-to-gif",
        files=files,
        data={"file": "video.mp4"},
        timeout=60,
    )
    resp.raise_for_status()

    # NOTE(review): this scrapes unstable HTML; ezgif normally needs a
    # second "convert" POST before a GIF <img> appears — confirm against
    # a live response before relying on this path.
    match = re.search(r'<img[^>]+src="([^"]+\.gif)"', resp.text)
    if match is None:
        raise RuntimeError("Failed to extract GIF URL from ezgif response")
    gif_url = match.group(1)

    # Make scheme-relative ("//…") or site-relative ("/…") URLs absolute.
    if gif_url.startswith("//"):
        gif_url = "https:" + gif_url
    elif gif_url.startswith("/"):
        gif_url = "https://s.ezgif.com" + gif_url

    gif_resp = requests.get(gif_url, timeout=30)
    gif_resp.raise_for_status()
    return gif_resp.content
66
+
67
+ # -------------------------------------------------
68
+ # Main inference function
69
+ # -------------------------------------------------
70
def generate_caption_from_url(url: str, prompt: str = "Describe the image.") -> str:
    """Caption the media behind *url*.

    Pipeline:
      1. Download the resource.
      2. If the URL ends in ``.mp4``, convert the video to a GIF first.
      3. Open the bytes with Pillow; for animated GIFs take the first frame.
      4. Run JoyCaption on the RGB frame and return the decoded caption.

    Raises:
        requests.HTTPError / RuntimeError: propagated from the download
            and conversion helpers.
        PIL.UnidentifiedImageError: if the bytes are not a readable image.
    """
    raw = download_bytes(url)

    # Detection is by file extension only — URLs without a ".mp4" suffix
    # are treated as images/GIFs.  (The original also reassigned a
    # ``lower_url`` variable here that was never read again; dropped.)
    if url.lower().endswith(".mp4"):
        raw = mp4_to_gif(raw)

    img = Image.open(BytesIO(raw))

    # Multi-frame GIF: caption the first frame only.
    if getattr(img, "is_animated", False):
        img = next(ImageSequence.Iterator(img))

    # The processor expects a 3-channel image.
    if img.mode != "RGB":
        img = img.convert("RGB")

    inputs = processor(images=img, text=prompt, return_tensors="pt")
    inputs = {k: v.to(llava_model.device) for k, v in inputs.items()}

    with torch.no_grad():
        out_ids = llava_model.generate(**inputs, max_new_tokens=64)

    return processor.decode(out_ids[0], skip_special_tokens=True)
116
 
117
# -------------------------------------------------
# Gradio UI
# -------------------------------------------------
url_box = gr.Textbox(
    label="Image / GIF / MP4 URL",
    placeholder="https://example.com/photo.jpg or https://example.com/clip.mp4",
)
prompt_box = gr.Textbox(label="Prompt (optional)", value="Describe the image.")
caption_box = gr.Textbox(label="Generated caption")

iface = gr.Interface(
    fn=generate_caption_from_url,
    inputs=[url_box, prompt_box],
    outputs=caption_box,
    title="JoyCaption – URL input (supports GIF & MP4)",
    description=(
        "Enter a direct URL to an image, an animated GIF, or an MP4 video. "
        "MP4 files are automatically converted to GIF via ezgif.com, "
        "and the first frame of the GIF is captioned."
    ),
    allow_flagging="never",
)
138
 
139
  if __name__ == "__main__":