import gc
import os
from functools import lru_cache
from typing import Any, Optional, Tuple
import gradio as gr
import spaces
import torch
from PIL import Image
from huggingface_hub import snapshot_download
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
)
token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
assert token, "Missing HF_TOKEN (add it in Space Secrets)."
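
# The token is only used when downloading the decoder tokenizer below; the Gemma
# fallback tokenizer may be gated on the Hub.
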
# -----------------------------
# Models
# -----------------------------
MODEL_IDS = [
"Spravil/caption-via-translation-0_4B",
"Spravil/caption-via-translation-0_4B-ft",
"Spravil/caption-via-translation-1_0B",
"Spravil/caption-via-translation-1_0B-ft",
"Spravil/caption-via-translation-3_5B",
"Spravil/caption-via-translation-3_5B-ft",
"Spravil/caption-via-translation-11_2B",
"Spravil/caption-via-translation-11_2B-ft",
]
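
# Only the smaller checkpoints expose the `use_cache` toggle in the UI
# (see _sync_cache_controls below).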
CACHEABLE_MODEL_IDS = [
"Spravil/caption-via-translation-0_4B",
"Spravil/caption-via-translation-0_4B-ft",
"Spravil/caption-via-translation-1_0B",
"Spravil/caption-via-translation-1_0B-ft",
]
CAPTION_TASKS = [
"<CAPTION>",
"<DETAILED_CAPTION>",
"<MORE_DETAILED_CAPTION>",
]
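
# Language codes are injected into prompts as <LANG_XX> tokens
# (see _caption_prompt and _translate_prompt below).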
LANGS = [
("English", "en"),
("German", "de"),
("French", "fr"),
("Spanish", "es"),
("Russian", "ru"),
("Chinese", "zh"),
]
# -----------------------------
# Runtime / device
# -----------------------------
HAS_CUDA = torch.cuda.is_available()
DEFAULT_DTYPE = torch.float16 if HAS_CUDA else torch.float32
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
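
# On CPU-only hosts, restrict the model dropdown to the smallest checkpoints.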
CPU_MODEL_IDS = [
"Spravil/caption-via-translation-0_4B",
"Spravil/caption-via-translation-0_4B-ft",
]
AVAILABLE_MODEL_IDS = MODEL_IDS if HAS_CUDA else CPU_MODEL_IDS
DEFAULT_MODEL_ID = "Spravil/caption-via-translation-0_4B"
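
# Resolve the Hub cache directory: prefer HF_HUB_CACHE, else <HF_HOME>/hub,
# otherwise fall back to huggingface_hub's default location (None).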
_hf_home = os.environ.get("HF_HOME")
HF_CACHE_DIR = os.environ.get("HF_HUB_CACHE") or (os.path.join(_hf_home, "hub") if _hf_home else None)
def _pick_decoder_tokenizer_name(model_path: str) -> str:
"""
Best-effort: infer decoder tokenizer name from config.
Fallback to gemma tokenizer (as in user's snippet).
"""
fallback = "google/gemma-2-2b"
try:
cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
for key in ["decoder_model_name_or_path", "text_model_name_or_path", "llm_model_name_or_path"]:
if hasattr(cfg, key):
v = getattr(cfg, key)
if isinstance(v, str) and v:
return v
if hasattr(cfg, "text_config") and hasattr(cfg.text_config, "_name_or_path"):
v = getattr(cfg.text_config, "_name_or_path")
if isinstance(v, str) and v:
return v
except Exception:
pass
return fallback
def _first_param_device(model: torch.nn.Module) -> torch.device:
    try:
        return next(model.parameters()).device
    except StopIteration:
        return torch.device("cpu")

@lru_cache(maxsize=16)
def _load_tokenizer(tokenizer_name: str) -> AutoTokenizer:
    return AutoTokenizer.from_pretrained(
        tokenizer_name,
        token=token,
        add_bos_token=True,
        add_eos_token=True,
        padding_side="right",
        truncation_side="right",
    )

@lru_cache(maxsize=16)
def _load_model_and_processor(model_id: str) -> Tuple[Any, Any]:
"""
Lazy-load a model + processor and cache them.
IMPORTANT: Florence2ForConditionalGeneration does NOT support device_map="auto".
"""
model_path = snapshot_download(model_id, cache_dir=HF_CACHE_DIR)
decoder_tok_name = _pick_decoder_tokenizer_name(model_path)
tokenizer = _load_tokenizer(decoder_tok_name)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if device.type == "cuda" else torch.float32
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch_dtype,
trust_remote_code=True,
low_cpu_mem_usage=True,
).to(device)
processor = AutoProcessor.from_pretrained(
model_path,
trust_remote_code=True,
new_tokenizer=tokenizer,
use_encoder_tokenizer=True,
)
model.eval()
return model, processor
def _prepare_image(image: Optional[Image.Image]) -> Image.Image:
    if image is None:
        raise gr.Error("Please upload an image.")
    return image.convert("RGB")

def _caption_prompt(lang: str, task: str) -> str:
return f"<LANG_{lang.upper()}>{task}"
def _translate_prompt(tgt_lang: str, source: str) -> str:
    if not source or not source.strip():
        raise gr.Error("Please provide text to translate.")
    return f"<LANG_{tgt_lang.upper()}><TRANSLATE>{source.strip()}"

@spaces.GPU()
@torch.inference_mode()
def run_caption(
    model_id: str,
    image: Optional[Image.Image],
    task: str,
    lang: str,
    max_new_tokens: int,
    num_beams: int,
    do_sample: bool,
    temperature: float,
    top_p: float,
    use_cache: bool,
) -> str:
    if model_id not in AVAILABLE_MODEL_IDS:
        raise gr.Error("Selected model requires a GPU environment.")
    pil_img = _prepare_image(image)
    model, processor = _load_model_and_processor(model_id)
    prompt = _caption_prompt(lang, task)
    inputs = processor(prompt, images=pil_img, return_tensors="pt")
    # Move inputs to the model's device (floating-point tensors are cast to fp16 on GPU).
    dev = _first_param_device(model)
    dtype = DEFAULT_DTYPE if dev.type == "cuda" else torch.float32
    inputs = inputs.to(dev, dtype)
    gen_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        num_beams=int(num_beams),
        do_sample=bool(do_sample),
        use_cache=bool(use_cache),
    )
    # temperature / top_p only apply when sampling is enabled.
    if do_sample:
        gen_kwargs.update(dict(temperature=float(temperature), top_p=float(top_p)))
    generated_ids = model.generate(**inputs, **gen_kwargs)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

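# Translation: same generation path as captioning, driven by a <LANG_XX><TRANSLATE> prompt.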
@spaces.GPU()
@torch.inference_mode()
def run_translate(
    model_id: str,
    source_text: str,
    target_lang: str,
    image: Optional[Image.Image],
    max_new_tokens: int,
    num_beams: int,
    do_sample: bool,
    temperature: float,
    top_p: float,
    use_cache: bool,
) -> str:
    if model_id not in AVAILABLE_MODEL_IDS:
        raise gr.Error("Selected model requires a GPU environment.")
    model, processor = _load_model_and_processor(model_id)
    prompt = _translate_prompt(target_lang, source_text)
    # The image is optional for translation; pass it through only if provided.
    pil_img = image.convert("RGB") if image is not None else None
    if pil_img is not None:
        inputs = processor(prompt, images=pil_img, return_tensors="pt")
    else:
        inputs = processor(prompt, return_tensors="pt")
    dev = _first_param_device(model)
    dtype = DEFAULT_DTYPE if dev.type == "cuda" else torch.float32
    inputs = inputs.to(dev, dtype)
    gen_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        num_beams=int(num_beams),
        do_sample=bool(do_sample),
        use_cache=bool(use_cache),
    )
    if do_sample:
        gen_kwargs.update(dict(temperature=float(temperature), top_p=float(top_p)))
    generated_ids = model.generate(**inputs, **gen_kwargs)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

def clear_caches() -> str:
    _load_model_and_processor.cache_clear()
    _load_tokenizer.cache_clear()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return "Cleared model/tokenizer caches and freed memory (best-effort)."

def _sync_cache_controls(
    model_id: str, cap_cache_value: bool, tr_cache_value: bool
) -> Tuple[gr.Checkbox, gr.Checkbox]:
    if model_id in CACHEABLE_MODEL_IDS:
        return (
            gr.update(value=bool(cap_cache_value), interactive=True),
            gr.update(value=bool(tr_cache_value), interactive=True),
        )
    return (
        gr.update(value=False, interactive=False),
        gr.update(value=False, interactive=False),
    )

# -----------------------------
# Gradio UI
# -----------------------------
with gr.Blocks(title="Caption via Translation – Space") as demo:
    gr.Markdown(
        """
# Caption via Translation – Demo

Pick a model and run either captioning or translation.
        """.strip()
    )

    with gr.Row():
        model_id_global = gr.Dropdown(choices=AVAILABLE_MODEL_IDS, value=DEFAULT_MODEL_ID, label="Model")
        clear_btn = gr.Button("Unload / Clear cache")

    cache_status = gr.Textbox(label="Cache status", value="", interactive=False)
    clear_btn.click(fn=clear_caches, inputs=[], outputs=[cache_status])

    with gr.Tabs():
        # -------------------------
        # Caption tab
        # -------------------------
        with gr.Tab("Caption"):
            with gr.Row():
                with gr.Column(scale=1):
                    cap_task = gr.Dropdown(choices=CAPTION_TASKS, value="<MORE_DETAILED_CAPTION>", label="Task")
                    cap_lang = gr.Dropdown(choices=[v for _, v in LANGS], value="de", label="Language (LANG_XX)")
                    cap_image = gr.Image(type="pil", label="Upload Image")

                    with gr.Accordion("Generation settings", open=False):
                        with gr.Row():
                            cap_max_new = gr.Slider(16, 512, value=128, step=1, label="max_new_tokens")
                            cap_beams = gr.Slider(1, 8, value=4, step=1, label="num_beams")
                            cap_use_cache = gr.Checkbox(value=False, label="use_cache")
                        with gr.Row():
                            cap_do_sample = gr.Checkbox(value=False, label="do_sample")
                            cap_temp = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="temperature")
                            cap_top_p = gr.Slider(0.05, 1.0, value=0.9, step=0.05, label="top_p")

                    cap_run = gr.Button("Generate caption", variant="primary")

                with gr.Column(scale=1):
                    cap_parsed = gr.Textbox(label="Parsed answer", lines=12)

            cap_run.click(
                fn=run_caption,
                inputs=[
                    model_id_global,
                    cap_image,
                    cap_task,
                    cap_lang,
                    cap_max_new,
                    cap_beams,
                    cap_do_sample,
                    cap_temp,
                    cap_top_p,
                    cap_use_cache,
                ],
                outputs=[cap_parsed],
            )

        # -------------------------
        # Translate tab
        # -------------------------
        with gr.Tab("Translate"):
            with gr.Row():
                with gr.Column(scale=1):
                    tgt_lang = gr.Dropdown(choices=[v for _, v in LANGS], value="de", label="Target language (LANG_XX)")
                    tr_image = gr.Image(type="pil", label="Upload Image")
                    src_text = gr.Textbox(
                        label="Source text",
                        placeholder="Type the text to translate…",
                        lines=8,
                    )

                    with gr.Accordion("Generation settings", open=False):
                        with gr.Row():
                            tr_max_new = gr.Slider(8, 512, value=128, step=1, label="max_new_tokens")
                            tr_beams = gr.Slider(1, 8, value=4, step=1, label="num_beams")
                            tr_use_cache = gr.Checkbox(value=False, label="use_cache")
                        with gr.Row():
                            tr_do_sample = gr.Checkbox(value=False, label="do_sample")
                            tr_temp = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="temperature")
                            tr_top_p = gr.Slider(0.05, 1.0, value=0.9, step=0.05, label="top_p")

                    tr_run = gr.Button("Translate", variant="primary")

                with gr.Column(scale=1):
                    tr_out = gr.Textbox(label="Output", lines=12)

            tr_run.click(
                fn=run_translate,
                inputs=[
                    model_id_global,
                    src_text,
                    tgt_lang,
                    tr_image,
                    tr_max_new,
                    tr_beams,
                    tr_do_sample,
                    tr_temp,
                    tr_top_p,
                    tr_use_cache,
                ],
                outputs=[tr_out],
            )

    model_id_global.change(
        fn=_sync_cache_controls,
        inputs=[model_id_global, cap_use_cache, tr_use_cache],
        outputs=[cap_use_cache, tr_use_cache],
    )

if __name__ == "__main__":
    demo.launch()