Hug0endob committed on
Commit
028a367
·
verified ·
1 Parent(s): 09c7c56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -40
app.py CHANGED
@@ -1,38 +1,20 @@
1
  import os
2
  import re
 
3
  import torch
4
  import requests
5
- from io import BytesIO
6
  from PIL import Image, ImageSequence
7
  from transformers import AutoProcessor, LlavaForConditionalGeneration
8
  import gradio as gr
9
 
10
- # ---------------------------
11
- # Config
12
- # ---------------------------
13
  MODEL_NAME = "fancyfeast/llama-joycaption-beta-one-hf-llava"
14
- HF_TOKEN = os.getenv("HF_TOKEN") # optional secret in Space settings
15
 
16
- # ---------------------------
17
- # Load model & processor
18
- # ---------------------------
19
- token_arg = {"token": HF_TOKEN} if HF_TOKEN else {}
20
- processor = AutoProcessor.from_pretrained(MODEL_NAME, **token_arg)
21
- llava_model = LlavaForConditionalGeneration.from_pretrained(
22
- MODEL_NAME,
23
- device_map="cpu",
24
- torch_dtype=torch.bfloat16,
25
- **token_arg,
26
- )
27
- llava_model.eval()
28
-
29
- # ---------------------------
30
- # Helpers
31
- # ---------------------------
32
  def download_bytes(url: str, timeout: int = 30) -> bytes:
33
- resp = requests.get(url, stream=True, timeout=timeout)
34
- resp.raise_for_status()
35
- return resp.content
36
 
37
  def mp4_to_gif(mp4_bytes: bytes) -> bytes:
38
  files = {"new-file": ("video.mp4", mp4_bytes, "video/mp4")}
@@ -53,21 +35,37 @@ def mp4_to_gif(mp4_bytes: bytes) -> bytes:
53
  gif_url = "https:" + gif_url
54
  elif gif_url.startswith("/"):
55
  gif_url = "https://s.ezgif.com" + gif_url
56
- gif_resp = requests.get(gif_url, timeout=60)
57
- gif_resp.raise_for_status()
58
- return gif_resp.content
59
 
60
  def load_first_frame_from_bytes(raw: bytes) -> Image.Image:
61
- img = Image.open(BytesIO(raw))
62
  if getattr(img, "is_animated", False):
63
  img = next(ImageSequence.Iterator(img))
64
  if img.mode != "RGB":
65
  img = img.convert("RGB")
66
  return img
67
 
68
- # ---------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  # Main inference
70
- # ---------------------------
71
  def generate_caption_from_url(url: str, prompt: str = "Describe the image.") -> str:
72
  if not url:
73
  return "No URL provided."
@@ -78,7 +76,7 @@ def generate_caption_from_url(url: str, prompt: str = "Describe the image.") ->
78
 
79
  lower = url.lower().split("?")[0]
80
  try:
81
- # crude MP4 detection by extension or ftyp box signature
82
  if lower.endswith(".mp4") or raw[:16].lower().find(b"ftyp") != -1:
83
  try:
84
  raw = mp4_to_gif(raw)
@@ -88,9 +86,36 @@ def generate_caption_from_url(url: str, prompt: str = "Describe the image.") ->
88
  except Exception as e:
89
  return f"Image processing error: {e}"
90
 
 
91
  try:
92
- inputs = processor(images=img, text=prompt, return_tensors="pt")
93
- inputs = {k: v.to(llava_model.device) for k, v in inputs.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  with torch.no_grad():
95
  out_ids = llava_model.generate(**inputs, max_new_tokens=128)
96
  caption = processor.decode(out_ids[0], skip_special_tokens=True)
@@ -98,10 +123,7 @@ def generate_caption_from_url(url: str, prompt: str = "Describe the image.") ->
98
  except Exception as e:
99
  return f"Inference error: {e}"
100
 
101
- # ---------------------------
102
- # Gradio UI (compatible init)
103
- # ---------------------------
104
- # Use try/except to support Gradio versions that don't accept allow_flagging
105
  gradio_kwargs = dict(
106
  fn=generate_caption_from_url,
107
  inputs=[
@@ -109,8 +131,8 @@ gradio_kwargs = dict(
109
  gr.Textbox(label="Prompt (optional)", value="Describe the image."),
110
  ],
111
  outputs=gr.Textbox(label="Generated caption"),
112
- title="JoyCaption (fancyfeast) - URL input",
113
- description="Paste a direct link to an image, GIF, or MP4. MP4 files are converted to GIF via ezgif.com; the first frame is captioned.",
114
  )
115
 
116
  try:
@@ -119,4 +141,14 @@ except TypeError:
119
  iface = gr.Interface(**gradio_kwargs)
120
 
121
  if __name__ == "__main__":
122
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import re
3
+ import io
4
  import torch
5
  import requests
 
6
  from PIL import Image, ImageSequence
7
  from transformers import AutoProcessor, LlavaForConditionalGeneration
8
  import gradio as gr
9
 
 
 
 
10
# Hugging Face model repo used for captioning.
MODEL_NAME = "fancyfeast/llama-joycaption-beta-one-hf-llava"

# Optional access token, read from the environment (e.g. a Space secret).
# None when the variable is unset — downstream code treats that as "no auth".
HF_TOKEN = os.getenv("HF_TOKEN")
12
 
13
+ # Helper: download bytes safely
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
def download_bytes(url: str, timeout: int = 30) -> bytes:
    """Download *url* and return the raw response body as bytes.

    Args:
        url: Direct link to the resource to fetch.
        timeout: Per-request timeout in seconds.

    Raises:
        requests.HTTPError: on a non-2xx status code.
        requests.RequestException: on connection errors or timeouts.
    """
    # Fix: the previous version passed stream=True and then read .content,
    # which downloads the entire body anyway — the flag was a misleading
    # no-op, so it is removed.  The `with` block still guarantees the
    # connection is released even if raise_for_status() throws.
    with requests.get(url, timeout=timeout) as resp:
        resp.raise_for_status()
        return resp.content
18
 
19
  def mp4_to_gif(mp4_bytes: bytes) -> bytes:
20
  files = {"new-file": ("video.mp4", mp4_bytes, "video/mp4")}
 
35
  gif_url = "https:" + gif_url
36
  elif gif_url.startswith("/"):
37
  gif_url = "https://s.ezgif.com" + gif_url
38
+ with requests.get(gif_url, timeout=60) as gif_resp:
39
+ gif_resp.raise_for_status()
40
+ return gif_resp.content
41
 
42
def load_first_frame_from_bytes(raw: bytes) -> Image.Image:
    """Decode *raw* image bytes and return the first frame as an RGB image.

    Animated formats (e.g. GIF) are reduced to their first frame; any
    non-RGB mode (palette, RGBA, grayscale, ...) is converted to RGB.
    """
    frame = Image.open(io.BytesIO(raw))
    # Animated images expose is_animated; plain stills do not, hence getattr.
    if getattr(frame, "is_animated", False):
        frame = next(ImageSequence.Iterator(frame))
    return frame if frame.mode == "RGB" else frame.convert("RGB")
49
 
50
# ---------------------------------------------------------------
# Load processor + model (runs once at import time; network-heavy)
# ---------------------------------------------------------------
# Fix: the old `token_arg` was built with the deprecated `use_auth_token`
# kwarg and then never used — each from_pretrained call rebuilt its own
# inline token dict instead.  Build the modern `token` kwarg dict once.
_token_kwargs = {"token": HF_TOKEN} if HF_TOKEN else {}

# Some HF model variants require trust_remote_code and
# num_additional_image_tokens.
processor = AutoProcessor.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    num_additional_image_tokens=1,  # safe default for forks that use a CLS token
    **_token_kwargs,
)

llava_model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    device_map="cpu",
    # NOTE(review): the model is pinned to CPU above, so branching on CUDA
    # availability for the dtype looks inconsistent — kept as-is to preserve
    # behavior, but worth confirming the intended device/dtype pairing.
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    trust_remote_code=True,
    **_token_kwargs,
)
llava_model.eval()  # inference only: disable dropout etc.
68
  # Main inference
 
69
  def generate_caption_from_url(url: str, prompt: str = "Describe the image.") -> str:
70
  if not url:
71
  return "No URL provided."
 
76
 
77
  lower = url.lower().split("?")[0]
78
  try:
79
+ # crude MP4 detection
80
  if lower.endswith(".mp4") or raw[:16].lower().find(b"ftyp") != -1:
81
  try:
82
  raw = mp4_to_gif(raw)
 
86
  except Exception as e:
87
  return f"Image processing error: {e}"
88
 
89
+ # Resize to safe resolution expected by many VLMs (adjust if your model docs say otherwise)
90
  try:
91
+ img = img.resize((512, 512), resample=Image.BICUBIC)
92
+ except Exception:
93
+ pass
94
+
95
+ try:
96
+ # Build conversation/chat input so processor inserts image placeholder correctly
97
+ conversation = [
98
+ {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}
99
+ ]
100
+ inputs = processor.apply_chat_template(
101
+ conversation,
102
+ add_generation_prompt=True,
103
+ return_tensors="pt",
104
+ return_dict=True,
105
+ images=img,
106
+ )
107
+ # Move to model device and match dtype for pixel values
108
+ device = llava_model.device
109
+ inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
110
+ if "pixel_values" in inputs:
111
+ inputs["pixel_values"] = inputs["pixel_values"].to(dtype=llava_model.dtype, device=device)
112
+
113
+ # Debug shapes (helpful if mismatch persists)
114
+ if "pixel_values" in inputs:
115
+ print("pixel_values.shape:", inputs["pixel_values"].shape)
116
+ if "input_ids" in inputs:
117
+ print("input_ids.shape:", inputs["input_ids"].shape)
118
+
119
  with torch.no_grad():
120
  out_ids = llava_model.generate(**inputs, max_new_tokens=128)
121
  caption = processor.decode(out_ids[0], skip_special_tokens=True)
 
123
  except Exception as e:
124
  return f"Inference error: {e}"
125
 
126
+ # Gradio UI
 
 
 
127
  gradio_kwargs = dict(
128
  fn=generate_caption_from_url,
129
  inputs=[
 
131
  gr.Textbox(label="Prompt (optional)", value="Describe the image."),
132
  ],
133
  outputs=gr.Textbox(label="Generated caption"),
134
+ title="JoyCaption - URL input",
135
+ description="Paste a direct link to an image/GIF/MP4 (MP4 will be converted).",
136
  )
137
 
138
  try:
 
141
  iface = gr.Interface(**gradio_kwargs)
142
 
143
if __name__ == "__main__":
    # launch() blocks until the server shuts down.
    # Fix: the previous version wrapped this in try/finally and closed the
    # asyncio event loop afterwards.  That used the deprecated
    # asyncio.get_event_loop() (which warns/raises outside a running loop on
    # modern Python) and risked closing a loop Gradio still owns; since the
    # whole cleanup was wrapped in `except Exception: pass` it was a silent
    # no-op at best, so it is removed — Gradio manages its own loop lifecycle.
    # 0.0.0.0:7860 is the bind address/port Hugging Face Spaces expects.
    iface.launch(server_name="0.0.0.0", server_port=7860)