Hug0endob committed on
Commit
7aed240
·
verified ·
1 Parent(s): 7e27c34

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -13
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
- import re
3
  import io
 
4
  import torch
5
  import requests
6
  from PIL import Image, ImageSequence
@@ -10,7 +10,6 @@ import gradio as gr
10
  MODEL_NAME = "fancyfeast/llama-joycaption-beta-one-hf-llava"
11
  HF_TOKEN = os.getenv("HF_TOKEN") # optional
12
 
13
- # Helper: download bytes safely
14
  def download_bytes(url: str, timeout: int = 30) -> bytes:
15
  with requests.get(url, stream=True, timeout=timeout) as resp:
16
  resp.raise_for_status()
@@ -49,23 +48,22 @@ def load_first_frame_from_bytes(raw: bytes) -> Image.Image:
49
 
50
  # Load processor + model
51
  token_arg = {"use_auth_token": HF_TOKEN} if HF_TOKEN else {}
52
- # Some HF model variants require trust_remote_code and num_additional_image_tokens
53
  processor = AutoProcessor.from_pretrained(
54
  MODEL_NAME,
55
  trust_remote_code=True,
56
- num_additional_image_tokens=1, # safe default for many forks that use a CLS token
57
  **({} if not HF_TOKEN else {"token": HF_TOKEN})
58
  )
 
59
  llava_model = LlavaForConditionalGeneration.from_pretrained(
60
  MODEL_NAME,
61
  device_map="cpu",
62
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
63
  trust_remote_code=True,
64
  **({} if not HF_TOKEN else {"token": HF_TOKEN})
65
  )
66
  llava_model.eval()
67
 
68
- # Main inference
69
  def generate_caption_from_url(url: str, prompt: str = "Describe the image.") -> str:
70
  if not url:
71
  return "No URL provided."
@@ -76,7 +74,6 @@ def generate_caption_from_url(url: str, prompt: str = "Describe the image.") ->
76
 
77
  lower = url.lower().split("?")[0]
78
  try:
79
- # crude MP4 detection
80
  if lower.endswith(".mp4") or raw[:16].lower().find(b"ftyp") != -1:
81
  try:
82
  raw = mp4_to_gif(raw)
@@ -86,14 +83,14 @@ def generate_caption_from_url(url: str, prompt: str = "Describe the image.") ->
86
  except Exception as e:
87
  return f"Image processing error: {e}"
88
 
89
- # Resize to safe resolution expected by many VLMs (adjust if your model docs say otherwise)
90
  try:
91
  img = img.resize((512, 512), resample=Image.BICUBIC)
92
  except Exception:
93
  pass
94
 
95
  try:
96
- # Build conversation/chat input so processor inserts image placeholder correctly
97
  conversation = [
98
  {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}
99
  ]
@@ -104,13 +101,12 @@ def generate_caption_from_url(url: str, prompt: str = "Describe the image.") ->
104
  return_dict=True,
105
  images=img,
106
  )
107
- # Move to model device and match dtype for pixel values
108
  device = llava_model.device
109
  inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
110
  if "pixel_values" in inputs:
111
  inputs["pixel_values"] = inputs["pixel_values"].to(dtype=llava_model.dtype, device=device)
112
 
113
- # Debug shapes (helpful if mismatch persists)
114
  if "pixel_values" in inputs:
115
  print("pixel_values.shape:", inputs["pixel_values"].shape)
116
  if "input_ids" in inputs:
@@ -123,7 +119,6 @@ def generate_caption_from_url(url: str, prompt: str = "Describe the image.") ->
123
  except Exception as e:
124
  return f"Inference error: {e}"
125
 
126
- # Gradio UI
127
  gradio_kwargs = dict(
128
  fn=generate_caption_from_url,
129
  inputs=[
@@ -144,7 +139,6 @@ if __name__ == "__main__":
144
  try:
145
  iface.launch(server_name="0.0.0.0", server_port=7860)
146
  finally:
147
- # close event loop safely in Spaces environment
148
  try:
149
  import asyncio
150
  loop = asyncio.get_event_loop()
 
1
  import os
 
2
  import io
3
+ import re
4
  import torch
5
  import requests
6
  from PIL import Image, ImageSequence
 
10
  MODEL_NAME = "fancyfeast/llama-joycaption-beta-one-hf-llava"
11
  HF_TOKEN = os.getenv("HF_TOKEN") # optional
12
 
 
13
  def download_bytes(url: str, timeout: int = 30) -> bytes:
14
  with requests.get(url, stream=True, timeout=timeout) as resp:
15
  resp.raise_for_status()
 
48
 
49
  # Load processor + model
50
  token_arg = {"use_auth_token": HF_TOKEN} if HF_TOKEN else {}
 
51
  processor = AutoProcessor.from_pretrained(
52
  MODEL_NAME,
53
  trust_remote_code=True,
54
+ num_additional_image_tokens=1,
55
  **({} if not HF_TOKEN else {"token": HF_TOKEN})
56
  )
57
+ # Use float32 on CPU; torch.bfloat16 may not be supported on CPU-only hardware
58
  llava_model = LlavaForConditionalGeneration.from_pretrained(
59
  MODEL_NAME,
60
  device_map="cpu",
61
+ torch_dtype=torch.float32,
62
  trust_remote_code=True,
63
  **({} if not HF_TOKEN else {"token": HF_TOKEN})
64
  )
65
  llava_model.eval()
66
 
 
67
  def generate_caption_from_url(url: str, prompt: str = "Describe the image.") -> str:
68
  if not url:
69
  return "No URL provided."
 
74
 
75
  lower = url.lower().split("?")[0]
76
  try:
 
77
  if lower.endswith(".mp4") or raw[:16].lower().find(b"ftyp") != -1:
78
  try:
79
  raw = mp4_to_gif(raw)
 
83
  except Exception as e:
84
  return f"Image processing error: {e}"
85
 
86
+ # Resize to a conservative size (512) expected by many VLMs
87
  try:
88
  img = img.resize((512, 512), resample=Image.BICUBIC)
89
  except Exception:
90
  pass
91
 
92
  try:
93
+ # Use chat-like conversation so processor inserts image token correctly
94
  conversation = [
95
  {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}
96
  ]
 
101
  return_dict=True,
102
  images=img,
103
  )
 
104
  device = llava_model.device
105
  inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
106
  if "pixel_values" in inputs:
107
  inputs["pixel_values"] = inputs["pixel_values"].to(dtype=llava_model.dtype, device=device)
108
 
109
+ # Debug prints (will appear in Space logs)
110
  if "pixel_values" in inputs:
111
  print("pixel_values.shape:", inputs["pixel_values"].shape)
112
  if "input_ids" in inputs:
 
119
  except Exception as e:
120
  return f"Inference error: {e}"
121
 
 
122
  gradio_kwargs = dict(
123
  fn=generate_caption_from_url,
124
  inputs=[
 
139
  try:
140
  iface.launch(server_name="0.0.0.0", server_port=7860)
141
  finally:
 
142
  try:
143
  import asyncio
144
  loop = asyncio.get_event_loop()