import os
from pathlib import Path
from typing import Iterable, List, Tuple

import gradio as gr
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration


def is_writable(path: Path) -> bool:
    """Return True if we can create and delete a probe file under `path`."""
    try:
        path.mkdir(parents=True, exist_ok=True)
        probe = path / ".probe"
        probe.write_text("ok", encoding="utf-8")
        probe.unlink(missing_ok=True)
        return True
    except Exception:
        return False
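

# Probing with a real file write (rather than os.access) also catches
# read-only mounts that still advertise write permission. Without persistent
# storage attached, /data may fail this probe and the app falls back to /tmp.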


def pick_writable_base() -> Path:
    for candidate in (
        os.getenv("SPACE_PERSISTENT_DIR"),
        "/data",
        "/app",
        "/tmp",
    ):
        if candidate and is_writable(Path(candidate)):
            return Path(candidate)
    return Path("/tmp")


def set_env_dir(key: str, path: Path) -> None:
    path.mkdir(parents=True, exist_ok=True)
    os.environ[key] = str(path)


BASE_DIR = pick_writable_base()
set_env_dir("HOME", BASE_DIR)
set_env_dir("XDG_CACHE_HOME", BASE_DIR / ".cache")
set_env_dir("HF_HOME", BASE_DIR / ".cache" / "huggingface")
set_env_dir("TRANSFORMERS_CACHE", BASE_DIR / ".cache" / "huggingface" / "transformers")
set_env_dir("HF_HUB_CACHE", BASE_DIR / ".cache" / "huggingface" / "hub")

os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["MKL_NUM_THREADS"] = "2"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch.set_num_threads(2)
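
# The 2-thread caps above assume the 2-vCPU free CPU tier on Spaces; raise the
# OMP/MKL/torch thread counts together if the Space runs on larger hardware.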

MODEL_REPO = "meettilavat/imagecaptioning"
SUBFOLDER_PREFIX = "outputs/blip2_full_ft_stage2"
LOCAL_DIR = Path(os.environ["HF_HOME"]) / "models" / "imagecaptioning"
DEFAULT_PROMPT = "Describe the image in detail."
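
# Expected layout inside MODEL_REPO, inferred from the allow patterns below
# (anything beyond what those patterns name is an assumption):
#   outputs/blip2_full_ft_stage2/model/      config + safetensors weights
#   outputs/blip2_full_ft_stage2/processor/  tokenizer + image-processor files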

SPINNER_MARKUP = """
<div class="caption-spinner">
  <div class="caption-spinner__loader" aria-hidden="true"></div>
  <span role="status">Generating caption...</span>
</div>
<style>
.caption-spinner {
    display: flex;
    align-items: center;
    gap: 0.5rem;
    font-size: 0.95rem;
}
.caption-spinner__loader {
    width: 20px;
    height: 20px;
    border: 3px solid rgba(0, 0, 0, 0.25);
    border-top-color: rgba(0, 0, 0, 0.75);
    border-radius: 50%;
    animation: caption-spin 0.75s linear infinite;
}
@keyframes caption-spin {
    to {
        transform: rotate(360deg);
    }
}
</style>
""".strip()

SPINNER_CONTAINER_CSS = """
<style>
#caption-spinner iframe {
    min-height: 48px;
    height: 48px;
    border: none;
    overflow: hidden;
}
</style>
""".strip()


def _allow_patterns() -> Iterable[str]:
    yield f"{SUBFOLDER_PREFIX}/model/config.json"
    yield f"{SUBFOLDER_PREFIX}/model/generation_config.json"
    yield f"{SUBFOLDER_PREFIX}/model/model.safetensors"
    yield f"{SUBFOLDER_PREFIX}/model/model.safetensors.index.json"
    yield f"{SUBFOLDER_PREFIX}/model/model-*.safetensors"
    yield f"{SUBFOLDER_PREFIX}/processor/*"


def prepare_local_snapshot() -> Path:
    """Download only the needed model/processor files into a writable dir."""
    # local_dir_use_symlinks is deprecated in recent huggingface_hub releases;
    # with local_dir set, real files are downloaded into place.
    root = snapshot_download(
        repo_id=MODEL_REPO,
        local_dir=str(LOCAL_DIR),
        allow_patterns=list(_allow_patterns()),
    )
    return Path(root)
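

# With the defaults above, the snapshot should land under
# <BASE_DIR>/.cache/huggingface/models/imagecaptioning/outputs/blip2_full_ft_stage2/.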


def load_model() -> Tuple[AutoProcessor, Blip2ForConditionalGeneration, torch.device, torch.dtype]:
    repo_root = prepare_local_snapshot()
    base = repo_root / SUBFOLDER_PREFIX
    processor_dir = base / "processor"
    model_dir = base / "model"

    device = torch.device("cpu")
    dtype: torch.dtype = torch.bfloat16

    processor = AutoProcessor.from_pretrained(processor_dir)
    try:
        # Prefer bfloat16 to roughly halve memory; fall back to float32 if the
        # checkpoint or torch build cannot load bf16 weights.
        model = Blip2ForConditionalGeneration.from_pretrained(
            model_dir,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        )
    except Exception:
        dtype = torch.float32
        model = Blip2ForConditionalGeneration.from_pretrained(
            model_dir,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        )
    model = model.to(device).eval()
    return processor, model, device, dtype


def generate_caption(
    processor: AutoProcessor,
    model: Blip2ForConditionalGeneration,
    device: torch.device,
    dtype: torch.dtype,
    image: Image.Image,
    prompt: str,
    max_new_tokens: int,
    num_beams: int,
) -> str:
    inputs = processor(images=image, text=prompt, return_tensors="pt")
    # Only pixel_values is cast to the model dtype; token ids and the
    # attention mask must stay integer tensors.
    pixel_values = inputs["pixel_values"].to(device=device, dtype=dtype)
    input_ids = inputs.get("input_ids")
    attention_mask = inputs.get("attention_mask")
    if input_ids is not None:
        input_ids = input_ids.to(device)
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)
    with torch.inference_mode():
        generated = model.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            do_sample=False,
        )
    return processor.batch_decode(generated, skip_special_tokens=True)[0].strip()
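

# Example call (hypothetical image path), once the globals below exist:
#   img = Image.open("photo.jpg").convert("RGB")
#   generate_caption(processor, model, device, dtype, img, DEFAULT_PROMPT, 56, 4)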


def batched_predictions(
    processor: AutoProcessor,
    model: Blip2ForConditionalGeneration,
    device: torch.device,
    dtype: torch.dtype,
    image: Image.Image,
    prompt: str,
    max_new_tokens: int,
    beam_options: List[int],
) -> List[Tuple[int, str]]:
    outputs: List[Tuple[int, str]] = []
    for beams in beam_options:
        caption = generate_caption(
            processor,
            model,
            device,
            dtype,
            image,
            prompt,
            max_new_tokens,
            beams,
        )
        outputs.append((beams, caption))
    return outputs
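

# Loading once at import time means the Space pays the download/load cost at
# startup rather than on the first user request.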
processor, model, device, dtype = load_model()


def run_inference(
    image: Image.Image,
    prompt: str,
    max_new_tokens: int,
    beam_mode: str,
    single_beam: int,
    compare_beams: List[str],
) -> str:
    if image is None:
        raise gr.Error("Please upload an image first.")
    clean_prompt = (prompt or "").strip() or DEFAULT_PROMPT
    if beam_mode == "Single":
        beam_list = [int(single_beam or 4)]
    else:
        default_options = [2, 4, 6]
        if not compare_beams:
            beam_list = default_options
        else:
            # De-duplicate while preserving order, capped at four beam widths.
            deduped: List[int] = []
            for value in compare_beams:
                beam = int(value)
                if beam not in deduped:
                    deduped.append(beam)
                if len(deduped) == 4:
                    break
            beam_list = deduped or default_options
    results = batched_predictions(
        processor,
        model,
        device,
        dtype,
        image.convert("RGB"),
        clean_prompt,
        int(max_new_tokens),  # sliders can deliver floats; generate() needs an int
        beam_list,
    )
    blocks = []
    for beams, text in results:
        blocks.append(f"**Beam width {beams}**\n{text}")
    return "\n\n".join(blocks)


def update_beam_visibility(choice: str):
    single_visible = choice == "Single"
    compare_visible = choice == "Compare"
    # Component-level .update() classmethods were removed in Gradio 4;
    # gr.update() works on both Gradio 3.x and 4.x.
    return (
        gr.update(visible=single_visible),
        gr.update(visible=compare_visible),
    )


def show_spinner():
    return gr.update(visible=True)


def hide_spinner():
    return gr.update(visible=False)


with gr.Blocks(title="BLIP-2 Image Captioning") as demo:
    gr.Markdown("# BLIP-2 Image Captioning (H200 fine-tuned)")
    gr.Markdown(
        "Upload an image, tweak decoding settings, and optionally compare beam widths side by side."
    )
    gr.HTML(SPINNER_CONTAINER_CSS, show_label=False)

    with gr.Row():
        with gr.Column(scale=6, min_width=320):
            image_input = gr.Image(
                label="Upload an image",
                type="pil",
                image_mode="RGB",
            )
            prompt_input = gr.Textbox(
                label="Prompt",
                value=DEFAULT_PROMPT,
                lines=3,
                placeholder="Describe the instruction for BLIP-2",
            )
            max_tokens_input = gr.Slider(
                label="Max new tokens",
                minimum=16,
                maximum=128,
                step=8,
                value=56,
            )
            beam_mode_input = gr.Radio(
                label="Beam mode",
                choices=["Single", "Compare"],
                value="Single",
                info="Use a single beam width or compare several options simultaneously.",
            )
            single_beam_slider = gr.Slider(
                label="Beam width",
                minimum=1,
                maximum=8,
                step=1,
                value=4,
            )
            compare_beams_group = gr.CheckboxGroup(
                label="Select beam widths",
                choices=[str(i) for i in range(1, 9)],
                value=["2", "4", "6"],
                interactive=True,
                visible=False,
            )
            run_button = gr.Button("Generate caption(s)")
        with gr.Column(scale=9):
            caption_output = gr.Markdown(value="Upload an image to preview captions.")
            gr.Markdown(
                f"Running inference on {device.type.upper()} with dtype {dtype}. "
                "Compare beams to balance diversity vs. precision."
            )
            spinner_display = gr.HTML(
                value=SPINNER_MARKUP,
                visible=False,
                show_label=False,
                elem_id="caption-spinner",
            )

    beam_mode_input.change(
        fn=update_beam_visibility,
        inputs=beam_mode_input,
        outputs=[single_beam_slider, compare_beams_group],
    )

    run_event = run_button.click(
        fn=show_spinner,
        outputs=spinner_display,
        show_progress="hidden",  # boolean show_progress is deprecated in Gradio 4
    )
    run_event = run_event.then(
        fn=run_inference,
        inputs=[
            image_input,
            prompt_input,
            max_tokens_input,
            beam_mode_input,
            single_beam_slider,
            compare_beams_group,
        ],
        outputs=caption_output,
        api_name="generate",
    )
    run_event.then(
        fn=hide_spinner,
        outputs=spinner_display,
        show_progress="hidden",
    )
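    # .then() fires even when run_inference raises, so the spinner is hidden
    # again after errors; chain with .success() instead if it should stay
    # visible on failure.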

if __name__ == "__main__":
    demo.launch()