"""OCR model orchestration — launch HF Jobs for multiple OCR models.""" from __future__ import annotations import time from dataclasses import dataclass, field import structlog from huggingface_hub import HfApi, get_token logger = structlog.get_logger() @dataclass class ModelConfig: """Configuration for a single OCR model.""" script: str model_id: str size: str default_flavor: str = "l4x1" default_args: list[str] = field(default_factory=list) MODEL_REGISTRY: dict[str, ModelConfig] = { "glm-ocr": ModelConfig( script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/glm-ocr.py", model_id="zai-org/GLM-OCR", size="0.9B", default_flavor="l4x1", ), "deepseek-ocr": ModelConfig( script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/deepseek-ocr-vllm.py", model_id="deepseek-ai/DeepSeek-OCR", size="4B", default_flavor="l4x1", default_args=["--prompt-mode", "free"], ), "lighton-ocr-2": ModelConfig( script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/lighton-ocr2.py", model_id="lightonai/LightOnOCR-2-1B", size="1B", default_flavor="a100-large", ), "dots-ocr": ModelConfig( script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/dots-ocr.py", model_id="rednote-hilab/dots.ocr", size="1.7B", default_flavor="l4x1", ), } DEFAULT_MODELS = ["glm-ocr", "deepseek-ocr", "lighton-ocr-2", "dots-ocr"] @dataclass class JobRun: """Tracks a launched HF Job.""" model_slug: str job_id: str job_url: str status: str = "running" def list_models() -> list[str]: """Return sorted list of available model slugs.""" return sorted(MODEL_REGISTRY.keys()) def build_script_args( input_dataset: str, output_repo: str, config_name: str, *, max_samples: int | None = None, shuffle: bool = False, seed: int = 42, extra_args: list[str] | None = None, ) -> list[str]: """Build the script_args list for run_uv_job.""" args = [ input_dataset, output_repo, "--config", config_name, "--create-pr", ] if max_samples is not None: args += ["--max-samples", str(max_samples)] if shuffle: args.append("--shuffle") if seed != 42: args += ["--seed", str(seed)] if extra_args: args += extra_args return args def launch_ocr_jobs( input_dataset: str, output_repo: str, *, models: list[str] | None = None, max_samples: int | None = None, split: str = "train", shuffle: bool = False, seed: int = 42, flavor_override: str | None = None, timeout: str = "4h", api: HfApi | None = None, ) -> list[JobRun]: """Launch HF Jobs for each model. Returns list of JobRun tracking objects.""" if api is None: api = HfApi() token = get_token() if not token: raise RuntimeError("No HF token found. Log in with `hf login` or set HF_TOKEN.") selected = models or DEFAULT_MODELS for slug in selected: if slug not in MODEL_REGISTRY: raise ValueError( f"Unknown model: {slug}. Available: {', '.join(MODEL_REGISTRY.keys())}" ) jobs: list[JobRun] = [] for slug in selected: config = MODEL_REGISTRY[slug] flavor = flavor_override or config.default_flavor script_args = build_script_args( input_dataset, output_repo, slug, max_samples=max_samples, shuffle=shuffle, seed=seed, extra_args=config.default_args or None, ) logger.info("launching_job", model=slug, flavor=flavor, script=config.script) job = api.run_uv_job( script=config.script, script_args=script_args, flavor=flavor, secrets={"HF_TOKEN": token}, timeout=timeout, ) jobs.append(JobRun(model_slug=slug, job_id=job.id, job_url=job.url)) logger.info("job_launched", model=slug, job_id=job.id, url=job.url) return jobs _TERMINAL_STAGES = frozenset({"COMPLETED", "ERROR", "CANCELED", "DELETED"}) def poll_jobs( jobs: list[JobRun], *, interval: int = 30, api: HfApi | None = None, ) -> list[JobRun]: """Poll until all jobs complete or fail. Updates status in-place and returns the list.""" if api is None: api = HfApi() pending = {j.job_id: j for j in jobs if j.status == "running"} while pending: time.sleep(interval) still_running: dict[str, JobRun] = {} for job_id, job_run in pending.items(): info = api.inspect_job(job_id=job_id) stage = info.status.stage if stage in _TERMINAL_STAGES: job_run.status = stage.lower() logger.info("job_finished", model=job_run.model_slug, status=job_run.status) else: still_running[job_id] = job_run pending = still_running if pending: slugs = [j.model_slug for j in pending.values()] logger.info("jobs_pending", models=slugs) return jobs