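"""Gradio app that serves a fine-tuned BLIP-2 captioning model on CPU.

Downloads the model snapshot from the Hub into a writable cache, loads it
once at startup, and exposes single-beam and beam-comparison captioning.
"""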
import functools
import os
from pathlib import Path
from typing import Iterable, List, Tuple
import gradio as gr
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration
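# Hugging Face Spaces containers are often read-only outside a few mount
# points, so probe candidate directories before using them for caches.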
def is_writable(path: Path) -> bool:
try:
path.mkdir(parents=True, exist_ok=True)
probe = path / ".probe"
probe.write_text("ok", encoding="utf-8")
probe.unlink(missing_ok=True)
return True
except Exception:
return False
def pick_writable_base() -> Path:
for candidate in (
os.getenv("SPACE_PERSISTENT_DIR"),
"/data",
"/app",
"/tmp",
):
if candidate and is_writable(Path(candidate)):
return Path(candidate)
return Path("/tmp")
def set_env_dir(key: str, path: Path) -> None:
path.mkdir(parents=True, exist_ok=True)
os.environ[key] = str(path)
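# Point HOME and every Hugging Face cache at the first writable directory so
# downloads succeed even when the container filesystem is mostly read-only.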
BASE_DIR = pick_writable_base()
set_env_dir("HOME", BASE_DIR)
set_env_dir("XDG_CACHE_HOME", BASE_DIR / ".cache")
set_env_dir("HF_HOME", BASE_DIR / ".cache" / "huggingface")
set_env_dir("TRANSFORMERS_CACHE", BASE_DIR / ".cache" / "huggingface" / "transformers")
set_env_dir("HF_HUB_CACHE", BASE_DIR / ".cache" / "huggingface" / "hub")
os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["MKL_NUM_THREADS"] = "2"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch.set_num_threads(2)
MODEL_REPO = "meettilavat/imagecaptioning"
SUBFOLDER_PREFIX = "outputs/blip2_full_ft_stage2"
LOCAL_DIR = Path(os.environ["HF_HOME"]) / "models" / "imagecaptioning"
DEFAULT_PROMPT = "Describe the image in detail."
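# Inline HTML spinner with scoped CSS, toggled visible while generation runs.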
SPINNER_MARKUP = """
<div class="caption-spinner">
<div class="caption-spinner__loader" aria-hidden="true"></div>
<span role="status">Generating caption...</span>
</div>
<style>
.caption-spinner {
display: flex;
align-items: center;
gap: 0.5rem;
font-size: 0.95rem;
}
.caption-spinner__loader {
width: 20px;
height: 20px;
border: 3px solid rgba(0, 0, 0, 0.25);
border-top-color: rgba(0, 0, 0, 0.75);
border-radius: 50%;
animation: caption-spin 0.75s linear infinite;
}
@keyframes caption-spin {
to {
transform: rotate(360deg);
}
}
</style>
""".strip()
SPINNER_CONTAINER_CSS = """
<style>
#caption-spinner iframe {
min-height: 48px;
height: 48px;
border: none;
overflow: hidden;
}
</style>
""".strip()
def _allow_patterns() -> Iterable[str]:
yield f"{SUBFOLDER_PREFIX}/model/config.json"
yield f"{SUBFOLDER_PREFIX}/model/generation_config.json"
yield f"{SUBFOLDER_PREFIX}/model/model.safetensors"
yield f"{SUBFOLDER_PREFIX}/model/model.safetensors.index.json"
yield f"{SUBFOLDER_PREFIX}/model/model-*.safetensors"
yield f"{SUBFOLDER_PREFIX}/processor/*"
@functools.lru_cache(maxsize=1)
def prepare_local_snapshot() -> Path:
    # local_dir_use_symlinks is deprecated and ignored by recent
    # huggingface_hub releases, so only the supported arguments are passed.
    root = snapshot_download(
        repo_id=MODEL_REPO,
        local_dir=str(LOCAL_DIR),
        allow_patterns=list(_allow_patterns()),
    )
    return Path(root)
@functools.lru_cache(maxsize=1)
def load_model() -> Tuple[AutoProcessor, Blip2ForConditionalGeneration, torch.device, torch.dtype]:
repo_root = prepare_local_snapshot()
base = repo_root / SUBFOLDER_PREFIX
processor_dir = base / "processor"
model_dir = base / "model"
device = torch.device("cpu")
dtype: torch.dtype = torch.bfloat16
processor = AutoProcessor.from_pretrained(processor_dir)
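    # Prefer bfloat16 to roughly halve memory; fall back to float32 when this
    # torch build (or the saved weights) cannot load in bf16 on CPU.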
try:
model = Blip2ForConditionalGeneration.from_pretrained(
model_dir,
torch_dtype=dtype,
low_cpu_mem_usage=True,
)
except Exception:
dtype = torch.float32
model = Blip2ForConditionalGeneration.from_pretrained(
model_dir,
torch_dtype=dtype,
low_cpu_mem_usage=True,
)
model = model.to(device).eval()
return processor, model, device, dtype
def generate_caption(
processor: AutoProcessor,
model: Blip2ForConditionalGeneration,
device: torch.device,
dtype: torch.dtype,
image: Image.Image,
prompt: str,
max_new_tokens: int,
num_beams: int,
) -> str:
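    # Cast pixel values to the model dtype; token ids and attention masks
    # stay integer tensors and are only moved to the target device.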
inputs = processor(images=image, text=prompt, return_tensors="pt")
pixel_values = inputs["pixel_values"].to(device=device, dtype=dtype)
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask")
if input_ids is not None:
input_ids = input_ids.to(device)
if attention_mask is not None:
attention_mask = attention_mask.to(device)
with torch.inference_mode():
generated = model.generate(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
do_sample=False,
)
return processor.batch_decode(generated, skip_special_tokens=True)[0].strip()
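# One full generate() call per beam width; beam search on CPU is sequential
# anyway, so compare mode simply runs the requested widths one after another.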
def batched_predictions(
processor: AutoProcessor,
model: Blip2ForConditionalGeneration,
device: torch.device,
dtype: torch.dtype,
image: Image.Image,
prompt: str,
max_new_tokens: int,
beam_options: List[int],
) -> List[Tuple[int, str]]:
outputs: List[Tuple[int, str]] = []
for beams in beam_options:
caption = generate_caption(
processor,
model,
device,
dtype,
image,
prompt,
max_new_tokens,
beams,
)
outputs.append((beams, caption))
return outputs
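# Load the model eagerly at import time so the download and weight loading
# happen during startup rather than on the first user request.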
processor, model, device, dtype = load_model()
def run_inference(
image: Image.Image,
prompt: str,
max_new_tokens: int,
beam_mode: str,
single_beam: int,
compare_beams: List[str],
) -> str:
if image is None:
raise gr.Error("Please upload an image first.")
clean_prompt = (prompt or "").strip() or DEFAULT_PROMPT
if beam_mode == "Single":
beam_list = [int(single_beam or 4)]
else:
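        # Compare mode: de-duplicate the selected widths in order, capped at
        # four runs to bound CPU time.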
default_options = [2, 4, 6]
if not compare_beams:
beam_list = default_options
else:
deduped = []
for value in compare_beams:
beam = int(value)
if beam not in deduped:
deduped.append(beam)
if len(deduped) == 4:
break
beam_list = deduped or default_options
results = batched_predictions(
processor,
model,
device,
dtype,
image.convert("RGB"),
clean_prompt,
max_new_tokens,
beam_list,
)
blocks = []
for beams, text in results:
blocks.append(f"**Beam width {beams}**\n{text}")
return "\n\n".join(blocks)
def update_beam_visibility(choice: str):
    # gr.update() works across Gradio 3 and 4; the per-component
    # ``.update()`` classmethods were removed in Gradio 4.
    single_visible = choice == "Single"
    compare_visible = choice == "Compare"
    return (
        gr.update(visible=single_visible),
        gr.update(visible=compare_visible),
    )
def show_spinner():
    return gr.update(visible=True)
def hide_spinner():
    return gr.update(visible=False)
with gr.Blocks(title="BLIP-2 Image Captioning") as demo:
gr.Markdown("# BLIP-2 Image Captioning (H200 fine-tuned)")
gr.Markdown(
"Upload an image, tweak decoding settings, and optionally compare beam widths side by side."
)
gr.HTML(SPINNER_CONTAINER_CSS, show_label=False)
with gr.Row():
with gr.Column(scale=6, min_width=320):
image_input = gr.Image(
label="Upload an image",
type="pil",
image_mode="RGB",
)
prompt_input = gr.Textbox(
label="Prompt",
value=DEFAULT_PROMPT,
lines=3,
placeholder="Describe the instruction for BLIP-2",
)
max_tokens_input = gr.Slider(
label="Max new tokens",
minimum=16,
maximum=128,
step=8,
value=56,
)
beam_mode_input = gr.Radio(
label="Beam mode",
choices=["Single", "Compare"],
value="Single",
info="Use a single beam width or compare several options simultaneously.",
)
single_beam_slider = gr.Slider(
label="Beam width",
minimum=1,
maximum=8,
step=1,
value=4,
)
compare_beams_group = gr.CheckboxGroup(
label="Select beam widths",
choices=[str(i) for i in range(1, 9)],
value=["2", "4", "6"],
interactive=True,
visible=False,
)
run_button = gr.Button("Generate caption(s)")
with gr.Column(scale=9):
caption_output = gr.Markdown(value="Upload an image to preview captions.")
            gr.Markdown(
                f"Running inference on {device.type.upper()} with dtype {dtype}. "
                "Wider beams trade extra CPU time for potentially better captions."
            )
spinner_display = gr.HTML(
value=SPINNER_MARKUP,
visible=False,
show_label=False,
elem_id="caption-spinner",
)
beam_mode_input.change(
fn=update_beam_visibility,
inputs=beam_mode_input,
outputs=[single_beam_slider, compare_beams_group],
)
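    # Event chain: show the spinner, run inference, then hide the spinner;
    # each .then() fires only after the previous step completes.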
    run_event = run_button.click(
        fn=show_spinner,
        outputs=spinner_display,
        show_progress="hidden",  # Gradio 4 expects a string here, not a bool
    )
run_event = run_event.then(
fn=run_inference,
inputs=[
image_input,
prompt_input,
max_tokens_input,
beam_mode_input,
single_beam_slider,
compare_beams_group,
],
outputs=caption_output,
api_name="generate",
)
    run_event.then(
        fn=hide_spinner,
        outputs=spinner_display,
        show_progress="hidden",
    )
if __name__ == "__main__":
    # Queue requests so long CPU generations are handled sequentially instead
    # of timing out concurrent connections (a no-op where queueing is already
    # the default).
    demo.queue().launch()