"""Gradio demo: image captioning with a fine-tuned BLIP-2 model, served on CPU.

Written with environments like Hugging Face Spaces in mind, where $HOME may be
read-only, so caches are redirected to the first writable base directory found.
"""

import functools
import os
from pathlib import Path
from typing import Iterable, List, Tuple

import gradio as gr
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration


def is_writable(path: Path) -> bool:
    """Return True if a probe file can be created and written under `path`."""
    try:
        path.mkdir(parents=True, exist_ok=True)
        probe = path / ".probe"
        probe.write_text("ok", encoding="utf-8")
        probe.unlink(missing_ok=True)
        return True
    except Exception:
        return False


def pick_writable_base() -> Path:
    """Pick the first writable base directory, falling back to /tmp."""
    for candidate in (
        os.getenv("SPACE_PERSISTENT_DIR"),
        "/data",
        "/app",
        "/tmp",
    ):
        if candidate and is_writable(Path(candidate)):
            return Path(candidate)
    return Path("/tmp")


def set_env_dir(key: str, path: Path) -> None:
    path.mkdir(parents=True, exist_ok=True)
    os.environ[key] = str(path)


# Point $HOME and every Hugging Face cache at a writable location.
BASE_DIR = pick_writable_base()
set_env_dir("HOME", BASE_DIR)
set_env_dir("XDG_CACHE_HOME", BASE_DIR / ".cache")
set_env_dir("HF_HOME", BASE_DIR / ".cache" / "huggingface")
set_env_dir("TRANSFORMERS_CACHE", BASE_DIR / ".cache" / "huggingface" / "transformers")
set_env_dir("HF_HUB_CACHE", BASE_DIR / ".cache" / "huggingface" / "hub")

# Keep CPU thread usage modest and disable telemetry.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["MKL_NUM_THREADS"] = "2"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch.set_num_threads(2)

MODEL_REPO = "meettilavat/imagecaptioning"
SUBFOLDER_PREFIX = "outputs/blip2_full_ft_stage2"
LOCAL_DIR = Path(os.environ["HF_HOME"]) / "models" / "imagecaptioning"
DEFAULT_PROMPT = "Describe the image in detail."

# Spinner shown while inference runs; styled by SPINNER_CONTAINER_CSS below.
SPINNER_MARKUP = """
<div class="caption-spinner">
    <span class="caption-spinner__dot"></span>
    <span>Generating caption...</span>
</div>
""".strip() SPINNER_CONTAINER_CSS = """ """.strip() def _allow_patterns() -> Iterable[str]: yield f"{SUBFOLDER_PREFIX}/model/config.json" yield f"{SUBFOLDER_PREFIX}/model/generation_config.json" yield f"{SUBFOLDER_PREFIX}/model/model.safetensors" yield f"{SUBFOLDER_PREFIX}/model/model.safetensors.index.json" yield f"{SUBFOLDER_PREFIX}/model/model-*.safetensors" yield f"{SUBFOLDER_PREFIX}/processor/*" @functools.lru_cache(maxsize=1) def prepare_local_snapshot() -> Path: root = snapshot_download( repo_id=MODEL_REPO, local_dir=str(LOCAL_DIR), local_dir_use_symlinks=False, allow_patterns=list(_allow_patterns()), ) return Path(root) @functools.lru_cache(maxsize=1) def load_model() -> Tuple[AutoProcessor, Blip2ForConditionalGeneration, torch.device, torch.dtype]: repo_root = prepare_local_snapshot() base = repo_root / SUBFOLDER_PREFIX processor_dir = base / "processor" model_dir = base / "model" device = torch.device("cpu") dtype: torch.dtype = torch.bfloat16 processor = AutoProcessor.from_pretrained(processor_dir) try: model = Blip2ForConditionalGeneration.from_pretrained( model_dir, torch_dtype=dtype, low_cpu_mem_usage=True, ) except Exception: dtype = torch.float32 model = Blip2ForConditionalGeneration.from_pretrained( model_dir, torch_dtype=dtype, low_cpu_mem_usage=True, ) model = model.to(device).eval() return processor, model, device, dtype def generate_caption( processor: AutoProcessor, model: Blip2ForConditionalGeneration, device: torch.device, dtype: torch.dtype, image: Image.Image, prompt: str, max_new_tokens: int, num_beams: int, ) -> str: inputs = processor(images=image, text=prompt, return_tensors="pt") pixel_values = inputs["pixel_values"].to(device=device, dtype=dtype) input_ids = inputs.get("input_ids") attention_mask = inputs.get("attention_mask") if input_ids is not None: input_ids = input_ids.to(device) if attention_mask is not None: attention_mask = attention_mask.to(device) with torch.inference_mode(): generated = model.generate( pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, num_beams=num_beams, do_sample=False, ) return processor.batch_decode(generated, skip_special_tokens=True)[0].strip() def batched_predictions( processor: AutoProcessor, model: Blip2ForConditionalGeneration, device: torch.device, dtype: torch.dtype, image: Image.Image, prompt: str, max_new_tokens: int, beam_options: List[int], ) -> List[Tuple[int, str]]: outputs: List[Tuple[int, str]] = [] for beams in beam_options: caption = generate_caption( processor, model, device, dtype, image, prompt, max_new_tokens, beams, ) outputs.append((beams, caption)) return outputs processor, model, device, dtype = load_model() def run_inference( image: Image.Image, prompt: str, max_new_tokens: int, beam_mode: str, single_beam: int, compare_beams: List[str], ) -> str: if image is None: raise gr.Error("Please upload an image first.") clean_prompt = (prompt or "").strip() or DEFAULT_PROMPT if beam_mode == "Single": beam_list = [int(single_beam or 4)] else: default_options = [2, 4, 6] if not compare_beams: beam_list = default_options else: deduped = [] for value in compare_beams: beam = int(value) if beam not in deduped: deduped.append(beam) if len(deduped) == 4: break beam_list = deduped or default_options results = batched_predictions( processor, model, device, dtype, image.convert("RGB"), clean_prompt, max_new_tokens, beam_list, ) blocks = [] for beams, text in results: blocks.append(f"**Beam width {beams}**\n{text}") return "\n\n".join(blocks) def 
def update_beam_visibility(choice: str):
    single_visible = choice == "Single"
    compare_visible = choice == "Compare"
    return (
        gr.update(visible=single_visible),
        gr.update(visible=compare_visible),
    )


def show_spinner():
    return gr.update(visible=True)


def hide_spinner():
    return gr.update(visible=False)


with gr.Blocks(title="BLIP-2 Image Captioning") as demo:
    gr.Markdown("# BLIP-2 Image Captioning (H200 fine-tuned)")
    gr.Markdown(
        "Upload an image, tweak decoding settings, and optionally compare beam widths side by side."
    )
    gr.HTML(SPINNER_CONTAINER_CSS, show_label=False)

    with gr.Row():
        with gr.Column(scale=6, min_width=320):
            image_input = gr.Image(
                label="Upload an image",
                type="pil",
                image_mode="RGB",
            )
            prompt_input = gr.Textbox(
                label="Prompt",
                value=DEFAULT_PROMPT,
                lines=3,
                placeholder="Describe the instruction for BLIP-2",
            )
            max_tokens_input = gr.Slider(
                label="Max new tokens",
                minimum=16,
                maximum=128,
                step=8,
                value=56,
            )
            beam_mode_input = gr.Radio(
                label="Beam mode",
                choices=["Single", "Compare"],
                value="Single",
                info="Use a single beam width or compare several options simultaneously.",
            )
            single_beam_slider = gr.Slider(
                label="Beam width",
                minimum=1,
                maximum=8,
                step=1,
                value=4,
            )
            compare_beams_group = gr.CheckboxGroup(
                label="Select beam widths",
                choices=[str(i) for i in range(1, 9)],
                value=["2", "4", "6"],
                interactive=True,
                visible=False,
            )
            run_button = gr.Button("Generate caption(s)")
        with gr.Column(scale=9):
            caption_output = gr.Markdown(value="Upload an image to preview captions.")
            gr.Markdown(
                f"Running inference on {device.type.upper()} with dtype {dtype}. "
                "Compare beam widths to trade decoding speed against caption quality."
            )
            spinner_display = gr.HTML(
                value=SPINNER_MARKUP,
                visible=False,
                show_label=False,
                elem_id="caption-spinner",
            )

    beam_mode_input.change(
        fn=update_beam_visibility,
        inputs=beam_mode_input,
        outputs=[single_beam_slider, compare_beams_group],
    )

    # Chain: show the spinner, run inference, then hide the spinner again.
    run_event = run_button.click(
        fn=show_spinner,
        outputs=spinner_display,
        show_progress=False,
    )
    run_event = run_event.then(
        fn=run_inference,
        inputs=[
            image_input,
            prompt_input,
            max_tokens_input,
            beam_mode_input,
            single_beam_slider,
            compare_beams_group,
        ],
        outputs=caption_output,
        api_name="generate",
    )
    run_event.then(
        fn=hide_spinner,
        outputs=spinner_display,
        show_progress=False,
    )


if __name__ == "__main__":
    demo.launch()
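
# Remote usage sketch for the named endpoint exposed via api_name="generate".
# Assumes the `gradio_client` package is installed; the URL and image path
# below are placeholders, not part of this app:
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860/")
#   caption = client.predict(
#       handle_file("sample.jpg"),        # image
#       "Describe the image in detail.",  # prompt
#       56,                               # max new tokens
#       "Single",                         # beam mode
#       4,                                # single beam width
#       [],                               # compare beam widths
#       api_name="/generate",
#   )
#   print(caption)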