"""OCR model orchestration — launch HF Jobs for multiple OCR models."""
from __future__ import annotations
import time
from dataclasses import dataclass, field
import structlog
from huggingface_hub import HfApi, get_token
logger = structlog.get_logger()
@dataclass
class ModelConfig:
    """Configuration for a single OCR model.

    Bundles the uv script that runs the model with the launch defaults
    used by launch_ocr_jobs.
    """

    # URL of the uv script executed by the HF Job.
    script: str
    # Hub repo id of the OCR model the script runs.
    model_id: str
    # Human-readable parameter count (informational; not used by launch logic here).
    size: str
    # Hardware flavor used when the caller gives no flavor_override.
    default_flavor: str = "l4x1"
    # Extra CLI args always appended for this model (e.g. prompt mode).
    default_args: list[str] = field(default_factory=list)
# Registry of supported OCR models, keyed by CLI-friendly slug.
MODEL_REGISTRY: dict[str, ModelConfig] = {
    "glm-ocr": ModelConfig(
        script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/glm-ocr.py",
        model_id="zai-org/GLM-OCR",
        size="0.9B",
        default_flavor="l4x1",
    ),
    "deepseek-ocr": ModelConfig(
        script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/deepseek-ocr-vllm.py",
        model_id="deepseek-ai/DeepSeek-OCR",
        size="4B",
        default_flavor="l4x1",
        default_args=["--prompt-mode", "free"],
    ),
    "lighton-ocr-2": ModelConfig(
        script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/lighton-ocr2.py",
        model_id="lightonai/LightOnOCR-2-1B",
        size="1B",
        default_flavor="a100-large",
    ),
    "dots-ocr": ModelConfig(
        script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/dots-ocr.py",
        model_id="rednote-hilab/dots.ocr",
        size="1.7B",
        default_flavor="l4x1",
    ),
}

# All registered models in registry (insertion) order. Derived from
# MODEL_REGISTRY instead of a hand-written copy so the two lists can
# never drift apart when a model is added or removed.
DEFAULT_MODELS = list(MODEL_REGISTRY)
@dataclass
class JobRun:
    """Tracks a launched HF Job."""

    # Registry slug of the model this job runs.
    model_slug: str
    # HF Jobs identifier, used when polling job state.
    job_id: str
    # Web URL for viewing the job.
    job_url: str
    # "running" until polling observes a terminal stage, then the
    # lowercased stage name (e.g. "completed", "error").
    status: str = "running"
def list_models() -> list[str]:
    """Return every available model slug, sorted alphabetically."""
    return sorted(MODEL_REGISTRY)
def build_script_args(
input_dataset: str,
output_repo: str,
config_name: str,
*,
max_samples: int | None = None,
shuffle: bool = False,
seed: int = 42,
extra_args: list[str] | None = None,
) -> list[str]:
"""Build the script_args list for run_uv_job."""
args = [
input_dataset,
output_repo,
"--config",
config_name,
"--create-pr",
]
if max_samples is not None:
args += ["--max-samples", str(max_samples)]
if shuffle:
args.append("--shuffle")
if seed != 42:
args += ["--seed", str(seed)]
if extra_args:
args += extra_args
return args
def launch_ocr_jobs(
    input_dataset: str,
    output_repo: str,
    *,
    models: list[str] | None = None,
    max_samples: int | None = None,
    split: str = "train",
    shuffle: bool = False,
    seed: int = 42,
    flavor_override: str | None = None,
    timeout: str = "4h",
    api: HfApi | None = None,
) -> list[JobRun]:
    """Launch one HF Job per selected model.

    Args:
        input_dataset: Hub dataset to OCR.
        output_repo: Dataset repo that receives results (via --create-pr).
        models: Registry slugs to run; defaults to DEFAULT_MODELS.
        max_samples: Optional sample cap, forwarded as --max-samples.
        split: NOTE(review): accepted but never used — it is not forwarded
            to build_script_args or the job. Confirm whether the OCR scripts
            take a --split flag and wire it through, or drop the parameter.
        shuffle: Forwarded as --shuffle when True.
        seed: Forwarded as --seed when not 42.
        flavor_override: Hardware flavor overriding each model's default.
        timeout: Job timeout string passed to run_uv_job.
        api: Optional HfApi instance; a fresh one is created when None.

    Returns:
        One JobRun tracking object per launched job, in selection order.

    Raises:
        RuntimeError: If no HF token is available.
        ValueError: If any requested slug is not in MODEL_REGISTRY.
    """
    if api is None:
        api = HfApi()
    token = get_token()
    if not token:
        raise RuntimeError("No HF token found. Log in with `hf login` or set HF_TOKEN.")
    selected = models or DEFAULT_MODELS
    # Validate every slug up front so nothing launches from a partially-bad list.
    for slug in selected:
        if slug not in MODEL_REGISTRY:
            raise ValueError(
                f"Unknown model: {slug}. Available: {', '.join(MODEL_REGISTRY.keys())}"
            )
    jobs: list[JobRun] = []
    for slug in selected:
        config = MODEL_REGISTRY[slug]
        flavor = flavor_override or config.default_flavor
        # The model slug doubles as the output dataset's config name.
        script_args = build_script_args(
            input_dataset,
            output_repo,
            slug,
            max_samples=max_samples,
            shuffle=shuffle,
            seed=seed,
            extra_args=config.default_args or None,
        )
        logger.info("launching_job", model=slug, flavor=flavor, script=config.script)
        job = api.run_uv_job(
            script=config.script,
            script_args=script_args,
            flavor=flavor,
            # The job needs the caller's token to push results to output_repo.
            secrets={"HF_TOKEN": token},
            timeout=timeout,
        )
        jobs.append(JobRun(model_slug=slug, job_id=job.id, job_url=job.url))
        logger.info("job_launched", model=slug, job_id=job.id, url=job.url)
    return jobs
# Job stages after which a job can never change state again; anything else
# is treated as still running by poll_jobs.
_TERMINAL_STAGES = frozenset({"COMPLETED", "ERROR", "CANCELED", "DELETED"})
def poll_jobs(
    jobs: list[JobRun],
    *,
    interval: int = 30,
    api: HfApi | None = None,
) -> list[JobRun]:
    """Poll until every job reaches a terminal stage.

    Each JobRun's status is updated in-place to the lowercased terminal
    stage; the same list is returned. Jobs whose status is not "running"
    on entry are skipped.

    Args:
        jobs: JobRun objects, typically from launch_ocr_jobs.
        interval: Seconds to sleep between polling rounds.
        api: Optional HfApi instance; a fresh one is created when None.

    Returns:
        The input list, statuses updated in-place.
    """
    if api is None:
        api = HfApi()
    pending = {j.job_id: j for j in jobs if j.status == "running"}
    first_round = True
    while pending:
        # Fix: poll immediately on the first round instead of sleeping first —
        # previously even already-finished jobs cost a full `interval` wait.
        if not first_round:
            time.sleep(interval)
        first_round = False
        still_running: dict[str, JobRun] = {}
        for job_id, job_run in pending.items():
            info = api.inspect_job(job_id=job_id)
            stage = info.status.stage
            if stage in _TERMINAL_STAGES:
                job_run.status = stage.lower()
                logger.info("job_finished", model=job_run.model_slug, status=job_run.status)
            else:
                still_running[job_id] = job_run
        pending = still_running
        if pending:
            slugs = [j.model_slug for j in pending.values()]
            logger.info("jobs_pending", models=slugs)
    return jobs