meettilavat commited on
Commit
a17ff38
·
verified ·
1 Parent(s): f454bfa

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +19 -0
  2. app.py +304 -0
  3. requirements.txt +7 -0
Dockerfile ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.13.5-slim

WORKDIR /app

# Copy only the requirements first so the pip layer is cached across code edits.
COPY requirements.txt ./

RUN pip install --no-cache-dir -r requirements.txt

COPY . .

# Gradio reads these at launch: bind on all interfaces at the Spaces default port.
ENV GRADIO_SERVER_NAME=0.0.0.0
ENV GRADIO_SERVER_PORT=7860
ENV PORT=7860

EXPOSE 7860

# Mark the container unhealthy when the web UI stops answering HTTP.
HEALTHCHECK CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:7860')" || exit 1

ENTRYPOINT ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Iterable, List, Tuple
5
+
6
+ import gradio as gr
7
+ import torch
8
+ from huggingface_hub import snapshot_download
9
+ from PIL import Image
10
+ from transformers import AutoProcessor, Blip2ForConditionalGeneration
11
+
12
+
13
+ def is_writable(path: Path) -> bool:
14
+ try:
15
+ path.mkdir(parents=True, exist_ok=True)
16
+ probe = path / ".probe"
17
+ probe.write_text("ok", encoding="utf-8")
18
+ probe.unlink(missing_ok=True)
19
+ return True
20
+ except Exception:
21
+ return False
22
+
23
+
24
def pick_writable_base() -> Path:
    """Return the first writable base directory for caches and downloads.

    Tries, in order: the Space persistent dir (if configured), /data, /app,
    and /tmp.  Falls back to /tmp unconditionally when nothing probes as
    writable.
    """
    candidates = (
        os.getenv("SPACE_PERSISTENT_DIR"),
        "/data",
        "/app",
        "/tmp",
    )
    for raw in candidates:
        if not raw:
            continue
        base = Path(raw)
        if is_writable(base):
            return base
    return Path("/tmp")
34
+
35
+
36
+ def set_env_dir(key: str, path: Path) -> None:
37
+ path.mkdir(parents=True, exist_ok=True)
38
+ os.environ[key] = str(path)
39
+
40
+
41
# Root for all caches/downloads: first writable of the Space persistent
# dir, /data, /app, /tmp (Spaces containers often have a read-only $HOME).
BASE_DIR = pick_writable_base()


# Redirect every cache-related env var into the writable base, so
# huggingface_hub / transformers never try to write outside it.
set_env_dir("HOME", BASE_DIR)
set_env_dir("XDG_CACHE_HOME", BASE_DIR / ".cache")
set_env_dir("HF_HOME", BASE_DIR / ".cache" / "huggingface")
set_env_dir("TRANSFORMERS_CACHE", BASE_DIR / ".cache" / "huggingface" / "transformers")
set_env_dir("HF_HUB_CACHE", BASE_DIR / ".cache" / "huggingface" / "hub")

# Keep the CPU-only container lean: disable analytics, cap thread counts.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["MKL_NUM_THREADS"] = "2"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

torch.set_num_threads(2)  # match the OMP/MKL caps above


# Hub repo holding the fine-tuned BLIP-2 checkpoint and the subfolder
# layout inside it (model/ and processor/ under this prefix).
MODEL_REPO = "meettilavat/imagecaptioning"
SUBFOLDER_PREFIX = "outputs/blip2_full_ft_stage2"
# Local download target, placed inside the HF cache chosen above.
LOCAL_DIR = Path(os.environ["HF_HOME"]) / "models" / "imagecaptioning"
DEFAULT_PROMPT = "Describe the image in detail."
62
+
63
+
64
def _allow_patterns() -> Iterable[str]:
    """Yield the hub glob patterns for just the files we need.

    Restricts the snapshot download to the model weights/config and the
    processor files under SUBFOLDER_PREFIX, skipping the rest of the repo.
    """
    weight_files = (
        "config.json",
        "generation_config.json",
        "model.safetensors",
        "model.safetensors.index.json",
        "model-*.safetensors",
    )
    for name in weight_files:
        yield f"{SUBFOLDER_PREFIX}/model/{name}"
    yield f"{SUBFOLDER_PREFIX}/processor/*"
71
+
72
+
73
@functools.lru_cache(maxsize=1)
def prepare_local_snapshot() -> Path:
    """Download the required repo files into LOCAL_DIR (once per process).

    Returns the snapshot root as a Path.  The lru_cache makes repeated
    calls free after the first successful download.
    """
    download_kwargs = {
        "repo_id": MODEL_REPO,
        "local_dir": str(LOCAL_DIR),
        "local_dir_use_symlinks": False,  # real files, not cache symlinks
        "allow_patterns": list(_allow_patterns()),
    }
    return Path(snapshot_download(**download_kwargs))
82
+
83
+
84
@functools.lru_cache(maxsize=1)
def load_model() -> Tuple[AutoProcessor, Blip2ForConditionalGeneration, torch.device, torch.dtype]:
    """Load the processor and BLIP-2 model from the local snapshot (cached).

    Tries bfloat16 first to roughly halve memory; when that load fails for
    any reason, retries in float32.  Returns (processor, model, device,
    dtype) with the model in eval mode on CPU.
    """
    stage_dir = prepare_local_snapshot() / SUBFOLDER_PREFIX
    processor = AutoProcessor.from_pretrained(stage_dir / "processor")
    weights_dir = stage_dir / "model"

    device = torch.device("cpu")
    dtype: torch.dtype = torch.bfloat16
    try:
        model = Blip2ForConditionalGeneration.from_pretrained(
            weights_dir,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        )
    except Exception:
        # bf16 load failed (checkpoint/runtime limitation): fall back to fp32.
        dtype = torch.float32
        model = Blip2ForConditionalGeneration.from_pretrained(
            weights_dir,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        )
    return processor, model.to(device).eval(), device, dtype
109
+
110
+
111
def generate_caption(
    processor: AutoProcessor,
    model: Blip2ForConditionalGeneration,
    device: torch.device,
    dtype: torch.dtype,
    image: Image.Image,
    prompt: str,
    max_new_tokens: int,
    num_beams: int,
) -> str:
    """Run one deterministic (beam-search, no sampling) caption generation.

    Pixel values are cast to the model dtype; the text tensors, when the
    processor produced them, are moved to the target device unchanged.
    Returns the decoded caption, stripped of whitespace.
    """
    encoded = processor(images=image, text=prompt, return_tensors="pt")

    def _on_device(tensor):
        # Text tensors may be absent depending on the processor config.
        return None if tensor is None else tensor.to(device)

    with torch.inference_mode():
        token_ids = model.generate(
            pixel_values=encoded["pixel_values"].to(device=device, dtype=dtype),
            input_ids=_on_device(encoded.get("input_ids")),
            attention_mask=_on_device(encoded.get("attention_mask")),
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            do_sample=False,
        )
    return processor.batch_decode(token_ids, skip_special_tokens=True)[0].strip()
141
+
142
+
143
def batched_predictions(
    processor: AutoProcessor,
    model: Blip2ForConditionalGeneration,
    device: torch.device,
    dtype: torch.dtype,
    image: Image.Image,
    prompt: str,
    max_new_tokens: int,
    beam_options: List[int],
) -> List[Tuple[int, str]]:
    """Caption *image* once per beam width, preserving the input order.

    Returns a list of (beam_width, caption) pairs, one per entry in
    *beam_options*.
    """
    return [
        (
            beams,
            generate_caption(
                processor,
                model,
                device,
                dtype,
                image,
                prompt,
                max_new_tokens,
                beams,
            ),
        )
        for beams in beam_options
    ]
167
+
168
+
169
# Load eagerly at import time so startup failures surface immediately and
# the first request is fast (load_model is lru_cached: this is the only load).
processor, model, device, dtype = load_model()
170
+
171
+
172
def run_inference(
    image: Image.Image,
    prompt: str,
    max_new_tokens: int,
    beam_mode: str,
    single_beam: int,
    compare_beams: List[str],
) -> str:
    """Gradio callback: caption the uploaded image.

    In "Single" mode one beam width is used; otherwise up to four distinct
    widths are run (defaulting to 2/4/6 when none are selected).  Returns
    a markdown string with one "**Beam width N**" section per result.
    Raises gr.Error when no image was supplied.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    clean_prompt = (prompt or "").strip() or DEFAULT_PROMPT

    fallback = [2, 4, 6]
    if beam_mode == "Single":
        beam_list = [int(single_beam or 4)]
    elif not compare_beams:
        beam_list = fallback
    else:
        # De-duplicate while preserving selection order; cap at 4 runs.
        chosen: List[int] = []
        for raw in compare_beams:
            width = int(raw)
            if width not in chosen:
                chosen.append(width)
            if len(chosen) == 4:
                break
        beam_list = chosen or fallback

    sections = [
        f"**Beam width {beams}**\n{text}"
        for beams, text in batched_predictions(
            processor,
            model,
            device,
            dtype,
            image.convert("RGB"),
            clean_prompt,
            max_new_tokens,
            beam_list,
        )
    ]
    return "\n\n".join(sections)
216
+
217
+
218
def update_beam_visibility(choice: str):
    """Toggle which beam-width control is shown for the selected mode.

    Returns updates for (single_beam_slider, compare_beams_group):
    "Single" shows the slider, "Compare" shows the checkbox group.
    """
    single_visible = choice == "Single"
    compare_visible = choice == "Compare"
    # Per-component `gr.Slider.update()` / `gr.CheckboxGroup.update()` were
    # removed in Gradio 4.x (requirements pin gradio>=4.44); `gr.update()`
    # is the supported way to patch component properties from a callback.
    return (
        gr.update(visible=single_visible),
        gr.update(visible=compare_visible),
    )
225
+
226
+
227
# --- UI layout: two columns, inputs on the left, captions on the right. ---
with gr.Blocks(title="BLIP-2 Image Captioning") as demo:
    gr.Markdown("# BLIP-2 Image Captioning (H200 fine-tuned)")
    gr.Markdown(
        "Upload an image, tweak decoding settings, and optionally compare beam widths side by side."
    )

    with gr.Row():
        # Left column: image upload plus all decoding controls.
        with gr.Column(scale=6, min_width=320):
            image_input = gr.Image(
                label="Upload an image",
                type="pil",  # run_inference expects a PIL.Image
                image_mode="RGB",
            )
            prompt_input = gr.Textbox(
                label="Prompt",
                value=DEFAULT_PROMPT,
                lines=3,
                placeholder="Describe the instruction for BLIP-2",
            )
            max_tokens_input = gr.Slider(
                label="Max new tokens",
                minimum=16,
                maximum=128,
                step=8,
                value=56,
            )
            beam_mode_input = gr.Radio(
                label="Beam mode",
                choices=["Single", "Compare"],
                value="Single",
                info="Use a single beam width or compare several options simultaneously.",
            )
            # Visible only in "Single" mode (toggled by update_beam_visibility).
            single_beam_slider = gr.Slider(
                label="Beam width",
                minimum=1,
                maximum=8,
                step=1,
                value=4,
            )
            # Visible only in "Compare" mode; values are strings that
            # run_inference converts to ints.
            compare_beams_group = gr.CheckboxGroup(
                label="Select beam widths",
                choices=[str(i) for i in range(1, 9)],
                value=["2", "4", "6"],
                interactive=True,
                visible=False,
            )
            run_button = gr.Button("Generate caption(s)")

        # Right column: markdown output plus a note on the inference setup.
        with gr.Column(scale=9):
            caption_output = gr.Markdown(value="Upload an image to preview captions.")
            gr.Markdown(
                f"Running inference on {device.type.upper()} with dtype {dtype}. "
                "Compare beams to balance diversity vs. precision."
            )

    # Swap visibility between the single-width slider and the compare group.
    beam_mode_input.change(
        fn=update_beam_visibility,
        inputs=beam_mode_input,
        outputs=[single_beam_slider, compare_beams_group],
    )

    run_button.click(
        fn=run_inference,
        inputs=[
            image_input,
            prompt_input,
            max_tokens_input,
            beam_mode_input,
            single_beam_slider,
            compare_beams_group,
        ],
        outputs=caption_output,
        api_name="generate",  # exposes the endpoint for programmatic clients
    )
301
+
302
+
303
if __name__ == "__main__":
    # Host/port come from GRADIO_SERVER_NAME / GRADIO_SERVER_PORT set in
    # the Dockerfile; gradio reads those env vars automatically.
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
# Runtime dependencies for the BLIP-2 captioning Space (CPU inference).
torch>=2.3,<2.8
transformers>=4.56
huggingface_hub>=0.24
timm>=1.0.19
sentencepiece>=0.2.1
gradio>=4.44
Pillow>=10.4