# Hugging Face Space — "Caption via Translation" demo
# Author: Julian Spravil
# Last commit: "Refactor caption and translation output handling for improved clarity and efficiency"
import gc
import os
from functools import lru_cache
from typing import Any, Optional, Tuple

import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
)
# Hugging Face access token, required to download the (gated/private) models.
token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
if not token:
    # Raise explicitly instead of `assert`: assertions are stripped under
    # `python -O`, which would let the Space start without credentials.
    raise RuntimeError("Missing HF_TOKEN (add it in Space Secrets).")
# -----------------------------
# Models
# -----------------------------
# Each checkpoint size ships as a base and a fine-tuned ("-ft") variant.
_MODEL_SIZES = ("0_4B", "1_0B", "3_5B", "11_2B")
MODEL_IDS = [
    f"Spravil/caption-via-translation-{size}{suffix}"
    for size in _MODEL_SIZES
    for suffix in ("", "-ft")
]
# Only the two smallest sizes (base + ft) support the use_cache toggle.
CACHEABLE_MODEL_IDS = MODEL_IDS[:4]
CAPTION_TASKS = [
    "<CAPTION>",
    "<DETAILED_CAPTION>",
    "<MORE_DETAILED_CAPTION>",
]
# (display name, LANG_XX code) pairs offered in the UI dropdowns.
LANGS = [
    ("English", "en"),
    ("German", "de"),
    ("French", "fr"),
    ("Spanish", "es"),
    ("Russian", "ru"),
    ("Chinese", "zh"),
]
# -----------------------------
# Runtime / device
# -----------------------------
HAS_CUDA = torch.cuda.is_available()
DEFAULT_DTYPE = torch.float16 if HAS_CUDA else torch.float32
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
# Without a GPU only the 0.4B checkpoints are practical to run.
CPU_MODEL_IDS = MODEL_IDS[:2]
AVAILABLE_MODEL_IDS = MODEL_IDS if HAS_CUDA else CPU_MODEL_IDS
DEFAULT_MODEL_ID = MODEL_IDS[0]
# Prefer the explicit hub cache; fall back to HF_HOME/hub when only HF_HOME is set.
_hf_home = os.environ.get("HF_HOME")
HF_CACHE_DIR = os.environ.get("HF_HUB_CACHE") or (os.path.join(_hf_home, "hub") if _hf_home else None)
| def _pick_decoder_tokenizer_name(model_path: str) -> str: | |
| """ | |
| Best-effort: infer decoder tokenizer name from config. | |
| Fallback to gemma tokenizer (as in user's snippet). | |
| """ | |
| fallback = "google/gemma-2-2b" | |
| try: | |
| cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True) | |
| for key in ["decoder_model_name_or_path", "text_model_name_or_path", "llm_model_name_or_path"]: | |
| if hasattr(cfg, key): | |
| v = getattr(cfg, key) | |
| if isinstance(v, str) and v: | |
| return v | |
| if hasattr(cfg, "text_config") and hasattr(cfg.text_config, "_name_or_path"): | |
| v = getattr(cfg.text_config, "_name_or_path") | |
| if isinstance(v, str) and v: | |
| return v | |
| except Exception: | |
| pass | |
| return fallback | |
| def _first_param_device(model: torch.nn.Module) -> torch.device: | |
| try: | |
| return next(model.parameters()).device | |
| except StopIteration: | |
| return torch.device("cpu") | |
@lru_cache(maxsize=8)
def _load_tokenizer(tokenizer_name: str):
    """
    Load and memoize the decoder tokenizer for ``tokenizer_name``.

    The ``lru_cache`` is required: ``clear_caches`` calls
    ``_load_tokenizer.cache_clear()``, which only exists on a cached
    function, and without caching every request re-instantiates the
    tokenizer. BOS/EOS are added on both ends; padding/truncation are
    right-sided to match the decoder's training setup.
    """
    return AutoTokenizer.from_pretrained(
        tokenizer_name,
        token=token,
        add_bos_token=True,
        add_eos_token=True,
        padding_side="right",
        truncation_side="right",
    )
@lru_cache(maxsize=1)
def _load_model_and_processor(model_id: str) -> Tuple[Any, Any]:
    """
    Lazy-load a model + processor and cache them (one model kept at a time).

    The ``lru_cache`` is required: ``clear_caches`` calls
    ``cache_clear()`` on this function, and without caching every
    generation request re-downloads/re-loads the weights. ``maxsize=1``
    bounds memory to a single resident model; switching models evicts
    the previous one.

    IMPORTANT: Florence2ForConditionalGeneration does NOT support
    device_map="auto", so the model is moved to the device explicitly.
    """
    model_path = snapshot_download(model_id, cache_dir=HF_CACHE_DIR)
    decoder_tok_name = _pick_decoder_tokenizer_name(model_path)
    tokenizer = _load_tokenizer(decoder_tok_name)
    device = torch.device("cuda:0" if HAS_CUDA else "cpu")
    torch_dtype = torch.float16 if device.type == "cuda" else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    ).to(device)
    # The processor wraps the custom decoder tokenizer so decode() uses it.
    processor = AutoProcessor.from_pretrained(
        model_path,
        trust_remote_code=True,
        new_tokenizer=tokenizer,
        use_encoder_tokenizer=True,
    )
    model.eval()
    return model, processor
| def _prepare_image(image: Optional[Image.Image]) -> Image.Image: | |
| if image is not None: | |
| return image.convert("RGB") | |
| raise gr.Error("Please upload an image.") | |
| def _caption_prompt(lang: str, task: str) -> str: | |
| return f"<LANG_{lang.upper()}>{task}" | |
| def _translate_prompt(tgt_lang: str, source: str) -> str: | |
| if not source or not source.strip(): | |
| raise gr.Error("Please provide text to translate.") | |
| return f"<LANG_{tgt_lang.upper()}><TRANSLATE>{source.strip()}" | |
def run_caption(
    model_id: str,
    image: Optional[Image.Image],
    task: str,
    lang: str,
    max_new_tokens: int,
    num_beams: int,
    do_sample: bool,
    temperature: float,
    top_p: float,
    use_cache: bool,
) -> str:
    """Generate a caption for the uploaded image with the selected model, task and language."""
    if model_id not in AVAILABLE_MODEL_IDS:
        raise gr.Error("Selected model requires a GPU environment.")
    rgb_image = _prepare_image(image)
    model, processor = _load_model_and_processor(model_id)
    batch = processor(_caption_prompt(lang, task), images=rgb_image, return_tensors="pt")
    device = _first_param_device(model)
    # Inputs must match the model's device; cast floats to fp16 only on CUDA.
    batch = batch.to(device, DEFAULT_DTYPE if device.type == "cuda" else torch.float32)
    generation_options = {
        "max_new_tokens": int(max_new_tokens),
        "num_beams": int(num_beams),
        "do_sample": bool(do_sample),
        "use_cache": bool(use_cache),
    }
    # Sampling knobs are only meaningful (and only passed) when sampling is on.
    if do_sample:
        generation_options["temperature"] = float(temperature)
        generation_options["top_p"] = float(top_p)
    output_ids = model.generate(**batch, **generation_options)
    return processor.batch_decode(output_ids, skip_special_tokens=True)[0]
def run_translate(
    model_id: str,
    source_text: str,
    target_lang: str,
    image: Optional[Image.Image],
    max_new_tokens: int,
    num_beams: int,
    do_sample: bool,
    temperature: float,
    top_p: float,
    use_cache: bool,
) -> str:
    """
    Translate ``source_text`` into ``target_lang``, optionally conditioned on an image.

    Raises gr.Error for unavailable models or empty source text. Returns the
    decoded translation string.
    """
    if model_id not in AVAILABLE_MODEL_IDS:
        raise gr.Error("Selected model requires a GPU environment.")
    # Validate and build the prompt BEFORE loading the model so empty input
    # fails fast instead of paying the expensive download/load first.
    prompt = _translate_prompt(target_lang, source_text)
    model, processor = _load_model_and_processor(model_id)
    # The image is optional for translation; normalize to RGB when given.
    pil_img = image.convert("RGB") if image is not None else None
    if pil_img is not None:
        inputs = processor(prompt, images=pil_img, return_tensors="pt")
    else:
        inputs = processor(prompt, return_tensors="pt")
    dev = _first_param_device(model)
    dtype = DEFAULT_DTYPE if dev.type == "cuda" else torch.float32
    inputs = inputs.to(dev, dtype)
    gen_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        num_beams=int(num_beams),
        do_sample=bool(do_sample),
        use_cache=bool(use_cache),
    )
    # temperature/top_p only apply when sampling.
    if do_sample:
        gen_kwargs.update(dict(temperature=float(temperature), top_p=float(top_p)))
    generated_ids = model.generate(**inputs, **gen_kwargs)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
def clear_caches() -> str:
    """
    Best-effort release of cached models/tokenizers and GPU memory.

    Uses ``getattr`` so this works whether or not the loader functions are
    wrapped in ``functools.lru_cache``: calling ``.cache_clear()`` directly
    on an undecorated function raises AttributeError and would crash the
    button handler.
    """
    for loader in (_load_model_and_processor, _load_tokenizer):
        cache_clear = getattr(loader, "cache_clear", None)
        if callable(cache_clear):
            cache_clear()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return "Cleared model/tokenizer caches and freed memory (best-effort)."
def _sync_cache_controls(
    model_id: str, cap_cache_value: bool, tr_cache_value: bool
) -> Tuple[gr.Checkbox, gr.Checkbox]:
    """
    Keep both use_cache checkboxes consistent with the selected model.

    Models outside CACHEABLE_MODEL_IDS get their checkboxes forced off and
    disabled; cacheable models keep the user's current values and stay
    interactive.
    """
    if model_id not in CACHEABLE_MODEL_IDS:
        disabled = gr.update(value=False, interactive=False)
        return disabled, disabled
    return (
        gr.update(value=bool(cap_cache_value), interactive=True),
        gr.update(value=bool(tr_cache_value), interactive=True),
    )
# -----------------------------
# Gradio UI
# -----------------------------
# Single-page demo: one shared model selector feeding two tabs (captioning
# and translation), each wired to the run_* handlers defined above.
with gr.Blocks(title="Caption via Translation – Space") as demo:
    gr.Markdown(
        """
# Caption via Translation – Demo
Pick a model and run either captioning or translation.
""".strip()
    )
    # Global controls shared by both tabs: model choice + cache management.
    with gr.Row():
        model_id_global = gr.Dropdown(choices=AVAILABLE_MODEL_IDS, value=DEFAULT_MODEL_ID, label="Model")
        clear_btn = gr.Button("Unload / Clear cache")
    cache_status = gr.Textbox(label="Cache status", value="", interactive=False)
    clear_btn.click(fn=clear_caches, inputs=[], outputs=[cache_status])
    with gr.Tabs():
        # -------------------------
        # Caption tab
        # -------------------------
        with gr.Tab("Caption"):
            with gr.Row():
                with gr.Column(scale=1):
                    cap_task = gr.Dropdown(choices=CAPTION_TASKS, value="<MORE_DETAILED_CAPTION>", label="Task")
                    # Dropdown stores only the LANG_XX code, not the display name.
                    cap_lang = gr.Dropdown(choices=[v for _, v in LANGS], value="de", label="Language (LANG_XX)")
                    cap_image = gr.Image(type="pil", label="Upload Image")
                    with gr.Accordion("Generation settings", open=False):
                        with gr.Row():
                            cap_max_new = gr.Slider(16, 512, value=128, step=1, label="max_new_tokens")
                            cap_beams = gr.Slider(1, 8, value=4, step=1, label="num_beams")
                            cap_use_cache = gr.Checkbox(value=False, label="use_cache")
                        with gr.Row():
                            # temperature/top_p only take effect when do_sample is checked
                            # (see run_caption's gen_kwargs handling).
                            cap_do_sample = gr.Checkbox(value=False, label="do_sample")
                            cap_temp = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="temperature")
                            cap_top_p = gr.Slider(0.05, 1.0, value=0.9, step=0.05, label="top_p")
                    cap_run = gr.Button("Generate caption", variant="primary")
                with gr.Column(scale=1):
                    cap_parsed = gr.Textbox(label="Parsed answer", lines=12)
            # Input order must match run_caption's positional signature exactly.
            cap_run.click(
                fn=run_caption,
                inputs=[
                    model_id_global,
                    cap_image,
                    cap_task,
                    cap_lang,
                    cap_max_new,
                    cap_beams,
                    cap_do_sample,
                    cap_temp,
                    cap_top_p,
                    cap_use_cache,
                ],
                outputs=[cap_parsed],
            )
        # -------------------------
        # Translate tab
        # -------------------------
        with gr.Tab("Translate"):
            with gr.Row():
                with gr.Column(scale=1):
                    tgt_lang = gr.Dropdown(choices=[v for _, v in LANGS], value="de", label="Target language (LANG_XX)")
                    # Image is optional for translation (see run_translate).
                    tr_image = gr.Image(type="pil", label="Upload Image")
                    src_text = gr.Textbox(
                        label="Source text",
                        placeholder="Type the text to translate…",
                        lines=8,
                    )
                    with gr.Accordion("Generation settings", open=False):
                        with gr.Row():
                            tr_max_new = gr.Slider(8, 512, value=128, step=1, label="max_new_tokens")
                            tr_beams = gr.Slider(1, 8, value=4, step=1, label="num_beams")
                            tr_use_cache = gr.Checkbox(value=False, label="use_cache")
                        with gr.Row():
                            tr_do_sample = gr.Checkbox(value=False, label="do_sample")
                            tr_temp = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="temperature")
                            tr_top_p = gr.Slider(0.05, 1.0, value=0.9, step=0.05, label="top_p")
                    tr_run = gr.Button("Translate", variant="primary")
                with gr.Column(scale=1):
                    tr_out = gr.Textbox(label="Output", lines=12)
            # Input order must match run_translate's positional signature exactly.
            tr_run.click(
                fn=run_translate,
                inputs=[
                    model_id_global,
                    src_text,
                    tgt_lang,
                    tr_image,
                    tr_max_new,
                    tr_beams,
                    tr_do_sample,
                    tr_temp,
                    tr_top_p,
                    tr_use_cache,
                ],
                outputs=[tr_out],
            )
    # Changing the model re-evaluates whether the use_cache toggles are allowed.
    model_id_global.change(
        fn=_sync_cache_controls,
        inputs=[model_id_global, cap_use_cache, tr_use_cache],
        outputs=[cap_use_cache, tr_use_cache],
    )
if __name__ == "__main__":
    demo.launch()