# Uploaded by davanstrien (HF Staff) via huggingface_hub — commit 1118181 (verified).
"""OCR model orchestration — launch HF Jobs for multiple OCR models."""
from __future__ import annotations
import time
from dataclasses import dataclass, field
import structlog
from huggingface_hub import HfApi, get_token
logger = structlog.get_logger()
@dataclass
class ModelConfig:
    """Configuration for a single OCR model."""
    # URL of the self-contained uv script that runs this model on HF Jobs.
    script: str
    # Hugging Face model repo id, e.g. "deepseek-ai/DeepSeek-OCR" (informational).
    model_id: str
    # Human-readable parameter-count label, e.g. "4B" (informational).
    size: str
    # Hardware flavor used when the caller gives no flavor_override.
    default_flavor: str = "l4x1"
    # Extra CLI args always appended for this model (see build_script_args).
    default_args: list[str] = field(default_factory=list)
# Registry of launchable OCR models, keyed by short slug.
# Each entry points at a uv script hosted in the uv-scripts/ocr dataset repo;
# insertion order is the order jobs are launched for DEFAULT_MODELS.
MODEL_REGISTRY: dict[str, ModelConfig] = {
    "glm-ocr": ModelConfig(
        script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/glm-ocr.py",
        model_id="zai-org/GLM-OCR",
        size="0.9B",
        default_flavor="l4x1",
    ),
    "deepseek-ocr": ModelConfig(
        script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/deepseek-ocr-vllm.py",
        model_id="deepseek-ai/DeepSeek-OCR",
        size="4B",
        default_flavor="l4x1",
        # This script takes a prompt-mode flag; "free" is the mode used here.
        default_args=["--prompt-mode", "free"],
    ),
    "lighton-ocr-2": ModelConfig(
        script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/lighton-ocr2.py",
        model_id="lightonai/LightOnOCR-2-1B",
        size="1B",
        # Only model defaulting to a bigger flavor than l4x1.
        default_flavor="a100-large",
    ),
    "dots-ocr": ModelConfig(
        script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/dots-ocr.py",
        model_id="rednote-hilab/dots.ocr",
        size="1.7B",
        default_flavor="l4x1",
    ),
}
# Models launched when the caller passes no explicit list. Derived from the
# registry so the two can never drift out of sync; dicts preserve insertion
# order, so this is exactly the registry's declaration order.
DEFAULT_MODELS = list(MODEL_REGISTRY)
@dataclass
class JobRun:
    """Tracks a launched HF Job."""
    # Slug into MODEL_REGISTRY identifying which model this job runs.
    model_slug: str
    # HF Jobs id, used for api.inspect_job polling.
    job_id: str
    # Web URL for the job (for humans/logs).
    job_url: str
    # "running" until poll_jobs observes a terminal stage, then the
    # lowercased stage name (e.g. "completed", "error").
    status: str = "running"
def list_models() -> list[str]:
    """Slugs of every registered OCR model, in alphabetical order."""
    # Iterating a dict yields its keys; no need for .keys().
    return sorted(MODEL_REGISTRY)
def build_script_args(
    input_dataset: str,
    output_repo: str,
    config_name: str,
    *,
    max_samples: int | None = None,
    shuffle: bool = False,
    seed: int = 42,
    extra_args: list[str] | None = None,
) -> list[str]:
    """Assemble the argument vector handed to an OCR uv script via run_uv_job.

    Layout: two positionals (input dataset, output repo), then --config and
    --create-pr, then optional flags in a fixed order, then any extra_args.
    """
    argv: list[str] = [input_dataset, output_repo, "--config", config_name, "--create-pr"]
    if max_samples is not None:
        argv.extend(["--max-samples", str(max_samples)])
    if shuffle:
        argv.append("--shuffle")
    # The scripts default to seed 42, so only pass it when it differs.
    if seed != 42:
        argv.extend(["--seed", str(seed)])
    argv.extend(extra_args or [])
    return argv
def launch_ocr_jobs(
    input_dataset: str,
    output_repo: str,
    *,
    models: list[str] | None = None,
    max_samples: int | None = None,
    split: str = "train",
    shuffle: bool = False,
    seed: int = 42,
    flavor_override: str | None = None,
    timeout: str = "4h",
    api: HfApi | None = None,
) -> list[JobRun]:
    """Launch one HF Job per selected model and return JobRun trackers.

    Raises RuntimeError when no HF token is available, and ValueError when
    any requested slug is not in MODEL_REGISTRY (checked before launching
    anything, so a bad slug never produces a partial batch).

    NOTE(review): `split` is accepted but never forwarded to the scripts —
    confirm whether the uv scripts take a split flag, or drop the parameter.
    """
    if api is None:
        api = HfApi()

    token = get_token()
    if not token:
        raise RuntimeError("No HF token found. Log in with `hf login` or set HF_TOKEN.")

    selected = models or DEFAULT_MODELS

    # Validate the whole selection up front.
    for slug in selected:
        if slug not in MODEL_REGISTRY:
            raise ValueError(
                f"Unknown model: {slug}. Available: {', '.join(MODEL_REGISTRY.keys())}"
            )

    runs: list[JobRun] = []
    for slug in selected:
        cfg = MODEL_REGISTRY[slug]
        chosen_flavor = flavor_override or cfg.default_flavor
        argv = build_script_args(
            input_dataset,
            output_repo,
            slug,  # each model writes to its own config named after the slug
            max_samples=max_samples,
            shuffle=shuffle,
            seed=seed,
            extra_args=cfg.default_args or None,
        )
        logger.info("launching_job", model=slug, flavor=chosen_flavor, script=cfg.script)
        launched = api.run_uv_job(
            script=cfg.script,
            script_args=argv,
            flavor=chosen_flavor,
            # The job needs the token to push results back to the Hub.
            secrets={"HF_TOKEN": token},
            timeout=timeout,
        )
        runs.append(JobRun(model_slug=slug, job_id=launched.id, job_url=launched.url))
        logger.info("job_launched", model=slug, job_id=launched.id, url=launched.url)
    return runs
# Job stages after which polling can stop (values of status.stage as
# reported by api.inspect_job).
_TERMINAL_STAGES = frozenset({"COMPLETED", "ERROR", "CANCELED", "DELETED"})
def poll_jobs(
    jobs: list[JobRun],
    *,
    interval: int = 30,
    api: HfApi | None = None,
) -> list[JobRun]:
    """Block until every job reaches a terminal stage.

    Mutates each JobRun's status in place (lowercased terminal stage name)
    and returns the same list. Sleeps `interval` seconds between polls,
    including before the very first poll.
    """
    if api is None:
        api = HfApi()

    active = {run.job_id: run for run in jobs if run.status == "running"}
    while active:
        time.sleep(interval)
        remaining: dict[str, JobRun] = {}
        for job_id, run in active.items():
            stage = api.inspect_job(job_id=job_id).status.stage
            if stage in _TERMINAL_STAGES:
                run.status = stage.lower()
                logger.info("job_finished", model=run.model_slug, status=run.status)
            else:
                remaining[job_id] = run
        active = remaining
        if active:
            logger.info("jobs_pending", models=[r.model_slug for r in active.values()])
    return jobs