achase25 commited on
Commit
5251d80
·
verified ·
1 Parent(s): adb4cdd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -105
app.py CHANGED
@@ -1,165 +1,112 @@
1
- # CPU-only Hugging Face Space: Image → Caption (Florence-2-base)
2
- # - No story model, only captioning.
3
- # - Florence runs without flash_attn via a small monkey patch.
4
- # - AVIF/HEIF image uploads supported via plugins.
5
- # - Batched Florence processor call (images=[...], padding=True).
6
 
7
  import os
8
- from typing import Dict, Any
9
 
10
  import torch
11
  import gradio as gr
12
  from PIL import Image
13
 
14
- # --- Enable AVIF/HEIF decoding for Pillow (handles .avif, .heic from phones) ---
15
  try:
16
- import pillow_avif # registers AVIF opener on import
17
  except Exception:
18
  pass
19
-
20
  try:
21
  from pillow_heif import register_heif_opener
22
  register_heif_opener()
23
  except Exception:
24
  pass
25
 
26
- from transformers import (
27
- AutoProcessor,
28
- AutoModelForCausalLM,
29
- )
30
 
31
- # -------------------- Config --------------------
32
  CAPTION_MODEL_ID = "microsoft/Florence-2-base"
33
- HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") or None
34
-
35
  DEVICE = "cpu"
36
- DTYPE = torch.float32
37
- MAX_IMG_SIDE = int(os.getenv("MAX_IMG_SIDE", "1024")) # bump if you want bigger inputs
38
 
39
- # -------------------- Cache --------------------
40
- _caption_bundle: Dict[str, Any] = {}
41
-
42
- # -------------------- Utils --------------------
43
  def _resize_max(img: Image.Image, max_side: int = MAX_IMG_SIDE) -> Image.Image:
44
  w, h = img.size
45
  if max(w, h) <= max_side:
46
  return img
47
- if w >= h:
48
- new_w = max_side
49
- new_h = int(h * (max_side / w))
50
- else:
51
- new_h = max_side
52
- new_w = int(w * (max_side / h))
53
- return img.resize((new_w, new_h), Image.LANCZOS)
54
 
55
- def _ensure_image(img) -> Image.Image:
56
  if not isinstance(img, Image.Image):
57
- raise gr.Error("Uploaded file is not a valid image.")
58
  return img.convert("RGB")
59
 
60
- # -------------------- Monkey patch: ignore flash_attn for Florence --------------------
61
- from unittest.mock import patch
62
- from transformers.dynamic_module_utils import get_imports as _orig_get_imports
63
-
64
- def _fixed_get_imports(filename):
65
- """Drop flash_attn requirement only for Florence modeling file (CPU-safe)."""
66
- imports = _orig_get_imports(filename)
67
  try:
68
  name = str(filename).lower()
69
  if "florence2" in name or "modeling_florence2.py" in name:
70
- return [imp for imp in imports if imp != "flash_attn"]
71
  except Exception:
72
  pass
73
- return imports
74
 
75
- # -------------------- Load Florence --------------------
76
- def get_caption_bundle() -> Dict[str, Any]:
77
- """Return {'processor': ..., 'model': ...} for Florence-2-base."""
78
- global _caption_bundle
79
- if _caption_bundle:
80
- return _caption_bundle
81
-
82
- processor = AutoProcessor.from_pretrained(
83
- CAPTION_MODEL_ID, trust_remote_code=True, token=HF_TOKEN
84
- )
85
- with patch("transformers.dynamic_module_utils.get_imports", _fixed_get_imports):
86
- model = AutoModelForCausalLM.from_pretrained(
87
  CAPTION_MODEL_ID,
88
  trust_remote_code=True,
89
  token=HF_TOKEN,
90
- attn_implementation="sdpa", # non-flash attention path
91
  torch_dtype=DTYPE,
92
  device_map="cpu",
93
  ).eval()
 
94
 
95
- _caption_bundle = {"processor": processor, "model": model}
96
- return _caption_bundle
97
-
98
- # -------------------- Caption op --------------------
99
  @torch.inference_mode()
100
- def op_caption(image: Image.Image, max_new_tokens: int = 128, num_beams: int = 3) -> str:
101
- """
102
- Florence-2-base caption (CPU):
103
- - Task tag: <MORE_DETAILED_CAPTION>
104
- - Batched call: images=[image], padding=True
105
- - post_process_generation parses the structured output
106
- """
107
- image = _ensure_image(image)
108
- image = _resize_max(image)
109
-
110
- bun = get_caption_bundle()
111
- processor, model = bun["processor"], bun["model"]
112
-
113
- inputs = processor(
114
  text="<MORE_DETAILED_CAPTION>",
115
- images=[image], # batch, even for single image
116
- padding=True, # ensure consistent tensor shapes
117
- return_tensors="pt"
118
  )
119
- # move tensors to device
120
- for k in list(inputs.keys()):
121
- if isinstance(inputs[k], torch.Tensor):
122
- inputs[k] = inputs[k].to(DEVICE)
123
 
124
- generated_ids = model.generate(
125
- **inputs,
126
  max_new_tokens=max_new_tokens,
127
- do_sample=False,
128
  num_beams=num_beams,
 
129
  early_stopping=False,
130
  )
131
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
132
-
133
  parsed = processor.post_process_generation(
134
- generated_text,
135
- task="<MORE_DETAILED_CAPTION>",
136
- image_size=[(image.width, image.height)], # list to match batched input
137
  )
138
- data = parsed[0] if isinstance(parsed, list) and parsed else parsed
139
- caption = (data.get("<MORE_DETAILED_CAPTION>", "") or "").strip()
140
- return caption or "Unable to generate a caption."
141
 
142
- # -------------------- Gradio UI --------------------
143
  def run(image: Image.Image):
144
- if image is None:
145
- raise gr.Error("Upload an image first.")
146
- text = op_caption(image)
147
- return text, "Task: caption • Model: Florence-2-base (CPU)"
148
 
149
- with gr.Blocks(css="footer {visibility:hidden}") as demo:
150
- gr.Markdown("# Image → Caption (CPU-only) — Florence-2-base")
151
  with gr.Row():
152
  with gr.Column():
153
- inp_img = gr.Image(
154
- type="pil",
155
- label="Image",
156
- sources=["upload", "clipboard", "webcam"],
157
- )
158
- go = gr.Button("Caption", variant="primary")
159
  with gr.Column():
160
- out_text = gr.Textbox(label="Caption", lines=10)
161
- status = gr.Markdown()
162
- go.click(run, inputs=[inp_img], outputs=[out_text, status], scroll_to_output=True)
163
 
164
  if __name__ == "__main__":
165
  demo.queue(max_size=8).launch()
 
1
+ # CPU-only: Image → Caption (Florence-2-base), concise build
 
 
 
 
2
 
3
  import os
4
+ from functools import lru_cache
5
 
6
  import torch
7
  import gradio as gr
8
  from PIL import Image
9
 
# AVIF/HEIF support (optional, safe to ignore if unavailable)
try:
    # Importing pillow_avif registers the AVIF opener with Pillow as an
    # import side effect; no further call is needed.
    import pillow_avif  # noqa: F401
except Exception:
    pass

try:
    # pillow_heif requires an explicit registration call, unlike pillow_avif.
    from pillow_heif import register_heif_opener
    register_heif_opener()
except Exception:
    # Best-effort: without the plugin, .heic/.avif uploads simply fail to decode.
    pass
 
21
+ from transformers import AutoProcessor, AutoModelForCausalLM
22
+ from unittest.mock import patch
23
+ from transformers.dynamic_module_utils import get_imports as _orig_get_imports
 
24
 
 
# -------------------- Config --------------------
CAPTION_MODEL_ID = "microsoft/Florence-2-base"
# Hub auth token: primary env var, then the alternate name; None means anonymous.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
DEVICE = "cpu"          # CPU-only deployment target
DTYPE = torch.float32   # fp32 for the CPU inference path
# Longest allowed image side before downscaling; overridable via env var.
MAX_IMG_SIDE = int(os.getenv("MAX_IMG_SIDE", "1024"))
 
 
 
 
 
31
def _resize_max(img: Image.Image, max_side: int = MAX_IMG_SIDE) -> Image.Image:
    """Downscale *img* so its longest side is at most *max_side*, keeping aspect ratio.

    Returns the image unchanged when it already fits within the limit.
    """
    width, height = img.size
    longest = max(width, height)
    if longest <= max_side:
        return img
    scale = max_side / longest
    new_size = (int(width * scale), int(height * scale))
    # LANCZOS gives the best quality for downscaling.
    return img.resize(new_size, Image.LANCZOS)
 
 
 
 
 
37
 
38
def _ensure_rgb(img) -> Image.Image:
    """Validate that *img* is a PIL image and normalize it to RGB mode.

    Raises:
        gr.Error: If *img* is not a PIL Image instance.
    """
    if isinstance(img, Image.Image):
        # Florence expects 3-channel input; convert handles RGBA/L/P modes.
        return img.convert("RGB")
    raise gr.Error("Upload a valid image.")
42
 
43
def _no_flash_attn_get_imports(filename):
    """Wrapper for transformers' get_imports that drops the flash_attn
    requirement from Florence-2's remote modeling file (flash_attn cannot
    be installed on a CPU-only Space).
    """
    imports = _orig_get_imports(filename)
    try:
        lowered = str(filename).lower()
        is_florence = "florence2" in lowered or "modeling_florence2.py" in lowered
    except Exception:
        # Unexpected filename object — leave the import list untouched.
        return imports
    if not is_florence:
        return imports
    return [entry for entry in imports if entry != "flash_attn"]
52
 
53
@lru_cache(maxsize=1)
def _load_florence():
    """Load the Florence-2-base processor and model for CPU inference.

    Returns:
        (processor, model) tuple; lru_cache ensures the download/load
        happens exactly once per process.
    """
    processor = AutoProcessor.from_pretrained(
        CAPTION_MODEL_ID,
        trust_remote_code=True,
        token=HF_TOKEN,
    )
    # Patch get_imports so transformers does not demand flash_attn while
    # importing Florence's trust_remote_code modeling file.
    with patch("transformers.dynamic_module_utils.get_imports", _no_flash_attn_get_imports):
        model = AutoModelForCausalLM.from_pretrained(
            CAPTION_MODEL_ID,
            trust_remote_code=True,
            token=HF_TOKEN,
            attn_implementation="sdpa",  # non-flash attention path, CPU-safe
            torch_dtype=DTYPE,
            device_map="cpu",
        )
    return processor, model.eval()
66
 
 
 
 
 
67
@torch.inference_mode()
def caption(image: Image.Image, max_new_tokens: int = 128, num_beams: int = 3) -> str:
    """Generate a detailed caption for *image* with Florence-2-base on CPU.

    Args:
        image: Input image; validated and converted to RGB, then downscaled
            to MAX_IMG_SIDE on the longest edge.
        max_new_tokens: Generation length cap.
        num_beams: Beam-search width (deterministic, no sampling).

    Returns:
        The parsed caption string, or a fallback message when empty.

    Raises:
        gr.Error: If *image* is not a PIL Image.
    """
    # Validate BEFORE resizing: _resize_max reads img.size, so a non-image
    # input would crash with AttributeError instead of the intended gr.Error.
    image = _resize_max(_ensure_rgb(image))
    processor, model = _load_florence()

    batch = processor(
        text="<MORE_DETAILED_CAPTION>",  # Florence task tag
        images=[image],                  # batch even for a single image
        padding=True,
        return_tensors="pt",
    )
    # Move tensors to the target device (BatchFeature may contain non-tensors).
    for key, value in list(batch.items()):
        if torch.is_tensor(value):
            batch[key] = value.to(DEVICE)

    out_ids = model.generate(
        **batch,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        do_sample=False,  # deterministic decoding
        early_stopping=False,
    )
    # Keep special tokens: post_process_generation needs them to parse
    # the structured task output.
    gen = processor.batch_decode(out_ids, skip_special_tokens=False)[0]
    parsed = processor.post_process_generation(
        gen, task="<MORE_DETAILED_CAPTION>", image_size=[(image.width, image.height)]
    )
    # Guard the empty-list case so indexing cannot raise IndexError.
    data = parsed[0] if isinstance(parsed, list) and parsed else parsed
    # Strip first, THEN fall back, so a whitespace-only caption still yields
    # the fallback message instead of an empty string.
    text = (data.get("<MORE_DETAILED_CAPTION>", "") or "").strip()
    return text or "Unable to generate a caption."
 
95
 
 
96
def run(image: Image.Image):
    """Gradio click handler: caption *image* and report the model used.

    Returns:
        (caption_text, status_markdown) tuple for the two output components.

    Raises:
        gr.Error: If no image was uploaded (Gradio passes None).
    """
    # Guard the empty upload so the user sees a clear message rather than an
    # AttributeError from the resize/convert pipeline downstream.
    if image is None:
        raise gr.Error("Upload an image first.")
    txt = caption(image)
    return txt, "Model: Florence-2-base (CPU)"
 
 
99
 
100
# -------------------- Gradio UI --------------------
# NOTE(review): the module-level name `demo` is presumably what the hosting
# platform (HF Spaces) serves — keep the name stable; confirm with deployment.
with gr.Blocks(css="footer{visibility:hidden}") as demo:
    gr.Markdown("# Image → Caption (CPU) — Florence-2-base")
    with gr.Row():
        with gr.Column():
            # type="pil" matches run()/caption(), which expect a PIL Image.
            img = gr.Image(type="pil", label="Image", sources=["upload", "clipboard", "webcam"])
            btn = gr.Button("Caption", variant="primary")
        with gr.Column():
            out = gr.Textbox(label="Caption", lines=10)
            status = gr.Markdown()
    btn.click(run, [img], [out, status], scroll_to_output=True)

if __name__ == "__main__":
    # Bounded queue limits concurrent CPU-bound caption requests.
    demo.queue(max_size=8).launch()