yolov11-ls-backend / train_inline.py
davanstrien's picture
davanstrien HF Staff
Drop deprecated 'message' kwarg from create_tag
f12c0b8 verified
"""
In-Space training cycle for the yolov11 newspaper detector (Pattern C-inline).
Differs from `train_cycle_from_ls.py` (the Job orchestrator) in two ways:
- Designed to be imported as a module and called from app.py, not run as a uv script.
- No /reload callback — the caller swaps the model in-process.
When to use this vs the Job pattern:
- Inline: model + training data fit in the Space's GPU memory; iteration speed matters.
- Job: training needs a beefier flavor than the Space, or the Space should stay cheap.
"""
from __future__ import annotations
import hashlib
import io
import logging
import os
import shutil
from collections import Counter
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import requests
import yaml
from datasets import Dataset, Features, Image as DSImage, Sequence, Value
from huggingface_hub import HfApi, create_repo, hf_hub_download, upload_file
from PIL import Image
from ultralytics import YOLO
# Module-level logger used by every function in this file, registered under the
# fixed name "train_inline" (not __name__).
log = logging.getLogger("train_inline")
# ---- helpers ---------------------------------------------------------------
def assign_split(item_id: str, val_ratio: float = 0.2, namespace: str = "") -> str:
    """Deterministically route *item_id* to ``"train"`` or ``"val"``.

    The id is salted with *namespace*, hashed with SHA-256, and the first eight
    digest bytes (big-endian) are reduced modulo 1000. Items whose bucket is
    below ``val_ratio * 1000`` go to validation. Because SHA-256 is deterministic,
    the same (namespace, id) pair always lands in the same split across runs.
    """
    salted_key = f"{namespace}::{item_id}".encode("utf-8")
    digest = hashlib.sha256(salted_key).digest()
    slot = int.from_bytes(digest[:8], byteorder="big") % 1000
    if slot < val_ratio * 1000:
        return "val"
    return "train"
def access_token(ls_url: str, refresh: str) -> str:
    """Exchange a Label Studio refresh token for an access token.

    Args:
        ls_url: Base URL of the Label Studio instance. A trailing "/" is tolerated
            (it would otherwise produce a "//api/..." path).
        refresh: The refresh token.

    Returns:
        The ``"access"`` field of the JSON response.

    Raises:
        requests.HTTPError: if the refresh request returns an error status.
        KeyError: if the response JSON has no ``"access"`` field.
    """
    r = requests.post(
        f"{ls_url.rstrip('/')}/api/token/refresh/",
        json={"refresh": refresh},
        timeout=30,
    )
    r.raise_for_status()
    return r.json()["access"]
def fetch_tasks(ls_url: str, access: str, project_id: int, page_size: int = 100) -> list[dict]:
    """Fetch every task of a Label Studio project, following pagination.

    Args:
        ls_url: Base URL of the Label Studio instance (a trailing "/" is tolerated).
        access: Bearer access token (see ``access_token``).
        project_id: Project whose tasks are listed (requested with ``fields=all``).
        page_size: Tasks requested per page; a page shorter than this ends the loop.

    Returns:
        All task dicts, in the order the server returned them.

    Raises:
        ValueError: if *page_size* < 1, or a dict response carries no task list.
        requests.HTTPError: for any error status other than the end-of-paging 404.
    """
    if page_size < 1:
        raise ValueError(f"page_size must be >= 1, got {page_size}")
    base = ls_url.rstrip("/")
    out: list[dict] = []
    page = 1
    while True:
        r = requests.get(
            f"{base}/api/tasks",
            headers={"Authorization": f"Bearer {access}"},
            params={"project": project_id, "page": page, "page_size": page_size, "fields": "all"},
            timeout=60,
        )
        # LS returns 404 when paging past the end (instead of an empty list)
        if r.status_code == 404:
            break
        r.raise_for_status()
        data = r.json()
        if isinstance(data, dict):
            # Envelope response: take the task list from a known key ("tasks", or
            # "results" for DRF-style pagination). A dict without one is an error;
            # extending `out` with the dict itself would add its *keys* as tasks.
            tasks = next((data[k] for k in ("tasks", "results") if k in data), None)
            if tasks is None:
                raise ValueError(f"unexpected /api/tasks response keys: {sorted(data)}")
        else:
            tasks = data
        if not tasks:
            break
        out.extend(tasks)
        if len(tasks) < page_size:
            break
        page += 1
    return out
def ls_box_to_xyxy(value: dict, W: int, H: int) -> tuple[float, float, float, float]:
    """Convert a Label Studio rectangle ``value`` into absolute pixel corners.

    ``x``, ``y``, ``width`` and ``height`` are percentages of the image size; each
    is divided by 100 and scaled by *W* (horizontal) or *H* (vertical). Returns
    ``(x1, y1, x2, y2)`` where ``(x1, y1)`` is the scaled ``x``/``y`` origin and
    ``(x2, y2)`` is that origin plus the scaled width/height.
    """

    def scaled(key: str, extent: int) -> float:
        return value[key] / 100.0 * extent

    left, top = scaled("x", W), scaled("y", H)
    return left, top, left + scaled("width", W), top + scaled("height", H)
def materialise_split(rows: list[dict[str, Any]], split_name: str, dest: Path, class_to_id: dict[str, int]) -> int:
    """Write one split to disk in the Ultralytics YOLO directory layout.

    Images are saved as RGB JPEG (quality 90) to
    ``dest/images/<split_name>/<split_name>_NNNNN.jpg``. Labels go to
    ``dest/labels/<split_name>/<same stem>.txt``, one ``cls cx cy w h`` line per
    box, with coordinates normalised by the image size.

    Args:
        rows: Records with a PIL ``image`` and parallel ``boxes`` (absolute
            ``[x1, y1, x2, y2]`` pixels) and ``labels`` (class-name) lists.
        split_name: Sub-directory / file-stem prefix, e.g. "train" or "val".
        dest: Dataset root; the sub-directories are created if missing.
        class_to_id: Class name -> YOLO class index. Boxes with other labels are dropped.

    Returns:
        Number of images written. Zero-size images are skipped. An image with
        no kept boxes still gets an (empty) label file.
    """
    img_dir = dest / "images" / split_name
    lbl_dir = dest / "labels" / split_name
    img_dir.mkdir(parents=True, exist_ok=True)
    lbl_dir.mkdir(parents=True, exist_ok=True)
    n_kept = 0
    for i, r in enumerate(rows):
        img: Image.Image = r["image"]
        if img.mode != "RGB":
            img = img.convert("RGB")
        W, H = img.size
        if W == 0 or H == 0:
            continue
        # Index-based stem: stable even when earlier rows were skipped.
        stem = f"{split_name}_{i:05d}"
        img.save(img_dir / f"{stem}.jpg", format="JPEG", quality=90)
        lines = []
        for box, label in zip(r["boxes"], r["labels"]):
            cls = class_to_id.get(label)
            if cls is None:
                continue
            x1, y1, x2, y2 = box
            # Clip to the image first. A box dragged past an edge gives coordinates
            # outside [0, W] x [0, H]. Ultralytics' label check rejects normalised
            # values outside [0, 1] and then skips the whole image as corrupt.
            x1, x2 = min(max(x1, 0.0), W), min(max(x2, 0.0), W)
            y1, y2 = min(max(y1, 0.0), H), min(max(y2, 0.0), H)
            w = (x2 - x1) / W
            h = (y2 - y1) / H
            if w <= 0 or h <= 0:
                # Degenerate, or entirely outside the image after clipping.
                continue
            cx = (x1 + x2) / 2.0 / W
            cy = (y1 + y2) / 2.0 / H
            lines.append(f"{cls} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}")
        (lbl_dir / f"{stem}.txt").write_text("\n".join(lines))
        n_kept += 1
    log.info("materialised %d %s items", n_kept, split_name)
    return n_kept
# ---- main entrypoint -------------------------------------------------------
def _latest_annotation(annotations: list[dict]) -> dict:
    """Pick a task's most recent annotation by ``updated_at``, else ``created_at``.

    ``or`` (rather than a ``.get`` default) lets a timestamp that is present but
    None fall through to the next candidate, instead of making ``max`` compare
    None with str.
    """
    return max(annotations, key=lambda a: a.get("updated_at") or a.get("created_at") or "")


def _fetch_page_image(image_url: str | None, ls_url: str, access: str) -> "Image.Image":
    """Download a task image and return it fully decoded, in RGB mode.

    Raises on any failure: missing URL, HTTP error, or an undecodable/truncated
    file. The caller treats that as "skip this task".
    """
    if not image_url:
        raise ValueError("task has no 'image' URL")
    headers: dict[str, str] = {}
    if image_url.startswith("/"):
        # Server-relative path (presumably a file hosted by Label Studio itself):
        # resolve it against ls_url and authenticate. Absolute URLs are fetched
        # without credentials, so the LS token is never sent to third-party hosts.
        image_url = ls_url.rstrip("/") + image_url
        headers["Authorization"] = f"Bearer {access}"
    r = requests.get(image_url, headers=headers, timeout=60)
    r.raise_for_status()
    img = Image.open(io.BytesIO(r.content))
    # Image.open is lazy. Decode now so a corrupt download fails here, where the
    # task can be skipped, not later during the Hub push or the YOLO export.
    img.load()
    if img.mode != "RGB":
        img = img.convert("RGB")
    return img


def _rectangles(results: list[dict], W: int, H: int) -> tuple[list[list[float]], list[str]]:
    """Turn an annotation's ``rectanglelabels`` results into pixel boxes + label names.

    Other result types are ignored. A rectangle with missing geometry is logged
    and skipped rather than aborting the cycle. A missing or empty label list
    becomes "unknown", which materialise_split drops unless it is a configured class.
    """
    boxes: list[list[float]] = []
    labels: list[str] = []
    for res in results:
        if res.get("type") != "rectanglelabels":
            continue
        v = res.get("value") or {}
        try:
            x1, y1, x2, y2 = ls_box_to_xyxy(v, W, H)
        except (KeyError, TypeError) as e:
            log.warning("skipping malformed rectangle %s: %s", res.get("id"), e)
            continue
        boxes.append([x1, y1, x2, y2])
        labels.append((v.get("rectanglelabels") or ["unknown"])[0])
    return boxes, labels


def _model_card(
    *,
    version_tag: str,
    baseline_repo: str,
    train_repo: str,
    val_repo: str,
    n_train: int,
    n_val: int,
    classes: list[str],
    label_counts: Counter,
    epochs: int,
    imgsz: int,
    batch: int,
    lr0: float,
) -> str:
    """Render the README.md (YAML front matter + summary) for one trained revision."""
    # The literal stays at column 0: the front matter must start at the line start.
    return f"""---
license: agpl-3.0
base_model:
- {baseline_repo}
datasets:
- {train_repo}
- {val_repo}
tags:
- object-detection
- newspapers
- historical-documents
- active-learning
- domain-adaptation
pipeline_tag: object-detection
---
# Europeana Newspaper Detector — Pattern C inline revision {version_tag}
Pattern C (inline) cycle from `{baseline_repo}` → 1890s-1930s European newspapers,
fine-tuned on human-reviewed annotations via the
[bootstrap-labels-skill](https://huggingface.co/spaces/davanstrien/yolov11-ls-backend) Space.
Trained inline in the Space (t4-small) — model is small enough that Job orchestration
overhead isn't worth paying for. See `references/active_learning_loop.md` for when to use
inline vs Job pattern.
## Cycle
- Revision: `{version_tag}`
- Train: {n_train} pages (`{train_repo}`)
- Val: {n_val} pages (`{val_repo}`)
- Classes: {classes}
- Train label distribution: {dict(label_counts)}
## Hyperparameters
- Epochs: {epochs}
- Image size: {imgsz}
- Batch: {batch}
- LR: {lr0}
"""


def run_training_cycle(
    *,
    ls_url: str,
    refresh_token: str,
    project_id: int,
    output_repo: str,
    train_repo: str = "davanstrien/newspaper-detector-train-v0",
    val_repo: str = "davanstrien/newspaper-detector-val-v0",
    baseline_repo: str = "small-models-for-glam/historic-newspaper-illustrations-yolov11",
    baseline_weights: str = "yolo11s.pt",
    classes: list[str] | None = None,
    val_ratio: float = 0.2,
    namespace: str = "newspaper-detector-v0",
    epochs: int = 15,
    imgsz: int = 1024,
    batch: int = 8,
    lr0: float = 1e-3,
    workdir: str = "/tmp/yolo_workdir",
) -> dict:
    """Run a full Pattern C training cycle inline. Returns metadata about the new revision.

    The cycle has six steps:

    1. Pull all tasks of *project_id* from Label Studio.
    2. Take each task's latest non-cancelled annotation, download its image and
       route it to train/val with ``assign_split``.
    3. Push both splits as Hub datasets.
    4. Export them in YOLO format under *workdir*. The directory is wiped first.
    5. Fine-tune *baseline_weights* from *baseline_repo*.
    6. Upload ``best.pt`` and a model card to *output_repo* and tag the revision.

    Raises on failure. Caller is responsible for state tracking + model swap.

    Returns:
        Dict with ``version`` (tag name), ``repo``, ``weights_path`` (local
        best.pt), ``n_train``, ``n_val`` and ``label_counts``.

    Raises:
        RuntimeError: no train examples were collected, or training left no best.pt.
        requests.HTTPError: Label Studio token refresh or task listing failed.
    """
    if classes is None:
        classes = ["Illustration", "Photograph"]
    class_to_id = {c: i for i, c in enumerate(classes)}
    log.info("run_training_cycle: project=%s classes=%s", project_id, class_to_id)

    log.info("Step 1/6: Pull LS annotations")
    access = access_token(ls_url, refresh_token)
    tasks = fetch_tasks(ls_url, access, project_id)
    log.info("Got %d tasks", len(tasks))

    train_examples: list[dict[str, Any]] = []
    val_examples: list[dict[str, Any]] = []
    log.info("Step 2/6: Apply online val split + collect reviewed pages")
    for t in tasks:
        annotations = t.get("annotations") or []
        if not annotations:
            continue
        ann = _latest_annotation(annotations)
        if ann.get("was_cancelled"):
            continue
        data = t.get("data") or {}
        # str(): the Hub schema types "id" as a string, and LS data ids may be ints.
        # assign_split formats the id with an f-string, so this does not move any
        # page between splits.
        page_id = str(data.get("id") or data.get("filename") or t.get("id"))
        try:
            img = _fetch_page_image(data.get("image"), ls_url, access)
        except Exception as e:  # best effort: one bad image must not abort the cycle
            log.warning("image fetch failed for %s: %s", page_id, e)
            continue
        W, H = img.size
        boxes, labels = _rectangles(ann.get("result") or [], W, H)
        year = data.get("year")
        rec = {
            "id": page_id, "image": img,
            "image_width": W, "image_height": H,
            "publication": data.get("publication", ""),
            "year": "" if year is None else str(year),
            "title": data.get("title", ""),
            "language": data.get("language", ""),
            "boxes": boxes, "labels": labels,
            "ls_task_id": t.get("id"),
            "ls_annotation_id": ann.get("id"),
        }
        split = assign_split(page_id, val_ratio=val_ratio, namespace=namespace)
        (val_examples if split == "val" else train_examples).append(rec)
    log.info("Routed: train=%d, val=%d", len(train_examples), len(val_examples))
    if not train_examples:
        raise RuntimeError("No train examples — need reviewed pages with boxes first.")

    log.info("Step 3/6: Push train + val Hub datasets")
    features = Features({
        "id": Value("string"), "image": DSImage(),
        "image_width": Value("int32"), "image_height": Value("int32"),
        "publication": Value("string"), "year": Value("string"),
        "title": Value("string"), "language": Value("string"),
        "boxes": Sequence(Sequence(Value("float32"), length=4)),
        "labels": Sequence(Value("string")),
        "ls_task_id": Value("int64"),
        "ls_annotation_id": Value("int64"),
    })
    iso = datetime.now(timezone.utc).isoformat()
    train_ds = Dataset.from_list(train_examples, features=features)
    create_repo(train_repo, repo_type="dataset", exist_ok=True, private=False)
    train_ds.push_to_hub(train_repo, private=False, commit_message=f"cycle {iso}: {len(train_examples)} train")
    if val_examples:
        val_ds = Dataset.from_list(val_examples, features=features)
        create_repo(val_repo, repo_type="dataset", exist_ok=True, private=False)
        val_ds.push_to_hub(val_repo, private=False, commit_message=f"cycle {iso}: {len(val_examples)} val")

    log.info("Step 4/6: Materialise YOLO format")
    wd = Path(workdir).resolve()
    if wd.exists():
        shutil.rmtree(wd)
    wd.mkdir(parents=True)
    n_train = materialise_split(train_examples, "train", wd, class_to_id)
    n_val = materialise_split(val_examples, "val", wd, class_to_id) if val_examples else 0
    label_counts: Counter = Counter(
        label for r in train_examples for label in r["labels"] if label in class_to_id
    )
    log.info("Train label distribution: %s", dict(label_counts))
    data_yaml = {
        "path": str(wd),
        "train": "images/train",
        # Without a val split, validate on train so Ultralytics still has a val set.
        "val": "images/val" if n_val > 0 else "images/train",
        "names": {i: c for c, i in class_to_id.items()},
        "nc": len(class_to_id),
    }
    yaml_path = wd / "data.yaml"
    yaml_path.write_text(yaml.safe_dump(data_yaml))

    log.info("Step 5/6: Fine-tune (epochs=%d, imgsz=%d, batch=%d, lr0=%.4g)", epochs, imgsz, batch, lr0)
    weights_path = hf_hub_download(repo_id=baseline_repo, filename=baseline_weights)
    model = YOLO(weights_path)
    model.train(
        data=str(yaml_path),
        epochs=epochs, imgsz=imgsz, batch=batch, lr0=lr0,
        project=str(wd / "runs"), name="ft",
        verbose=True,
    )
    best = wd / "runs" / "ft" / "weights" / "best.pt"
    if not best.exists():
        # Fallback in case the run landed elsewhere under runs/. The workdir was
        # wiped above, so any match belongs to this cycle; sort for determinism.
        candidates = sorted((wd / "runs").rglob("best.pt"))
        if not candidates:
            raise RuntimeError(f"no best.pt found in {wd}")
        best = candidates[0]
    log.info("Best weights: %s (%.1f MB)", best, best.stat().st_size / 1e6)

    log.info("Step 6/6: Push fine-tuned model")
    version_tag = datetime.now(timezone.utc).strftime("v%Y%m%d-%H%M%S")
    create_repo(output_repo, repo_type="model", exist_ok=True, private=False)
    card = _model_card(
        version_tag=version_tag,
        baseline_repo=baseline_repo,
        train_repo=train_repo,
        val_repo=val_repo,
        n_train=n_train,
        n_val=n_val,
        classes=classes,
        label_counts=label_counts,
        epochs=epochs,
        imgsz=imgsz,
        batch=batch,
        lr0=lr0,
    )
    (wd / "README.md").write_text(card)
    api = HfApi()
    upload_file(
        path_or_fileobj=str(best), path_in_repo="best.pt",
        repo_id=output_repo, repo_type="model",
        commit_message=f"Pattern C inline cycle {version_tag}: {n_train} train + {n_val} val",
    )
    upload_file(
        path_or_fileobj=str(wd / "README.md"), path_in_repo="README.md",
        repo_id=output_repo, repo_type="model",
        commit_message=f"Card for {version_tag}",
    )
    try:
        api.create_tag(repo_id=output_repo, repo_type="model", tag=version_tag)
    except Exception as e:  # a missing tag is non-fatal: the weights are already uploaded
        log.warning("tag create failed: %s", e)
    return {
        "version": version_tag,
        "repo": output_repo,
        "weights_path": str(best),
        "n_train": n_train,
        "n_val": n_val,
        "label_counts": dict(label_counts),
    }