| """ | |
| OpenEnv GRPO training (production-oriented, simple stack). | |
| Produces real training artifacts (trainer log_history, metrics JSON, reward plots) and | |
| optional Hub push of LoRA weights. Every execution reward calls your live Space (or | |
| local server) at OPENENV_BASE_URL β not a mock. | |
| Environment (control cost vs quality on HF Jobs / local GPU): | |
| OPENENV_BASE_URL β OpenEnv HTTP root (default: Space URL from openenv.yaml) | |
| OPENENV_TASK_IDS β Comma list; if unset, uses GET /tasks from the server | |
| ROWS_PER_TASK β GRPO rows per task_id (default: 48) | |
| OPENENV_REQUEST_TIMEOUT_SEC β HTTP timeout for reset/step (default: 120) | |
| TRAIN_MAX_STEPS β GRPO steps (default 200) | |
| TRL_REPORT_TO β none | wandb | tensorboard (auto: wandb if key else none) | |
| BOOTSTRAP_*_VERSION β pin transformers / accelerate / trl (defaults satisfy trl>=4.50) | |
| Artifacts: artifacts/train_log_history.jsonl, metrics, plots | |
| HF_HUB_REPO_ID β push target (default md896/sota-sql-agent-7b) | |
| SKIP_HUB_PUSH=1 β do not push after train | |
| HF_TOKEN / HUGGING_FACE_HUB_TOKEN β Hub auth for push | |
| Designed for Hugging Face Jobs / Spaces where: | |
| - system Python may be externally managed (PEP-668) β uses --break-system-packages | |
| - preinstalled CUDA/PyTorch stacks can conflict with optional vision packages | |
| Key stability choices: | |
| - Avoid importing torchvision in text-only runs (it can break when torch/torchvision | |
| versions are mismatched by dependency resolution). | |
| - Produce plots and metrics from the *actual* GRPO run (no hard-coded scores). | |
| """ | |
from __future__ import annotations

import json
import os
import random
import subprocess
import sys
import time
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional


def _run(cmd: List[str], *, check: bool = True) -> subprocess.CompletedProcess:
    return subprocess.run(cmd, check=check)


def _pip(args: List[str], *, check: bool = True) -> subprocess.CompletedProcess:
    return _run([sys.executable, "-m", "pip", *args], check=check)
def bootstrap_deps() -> None:
    """
    Best-effort dependency bootstrap for ephemeral HF containers.

    Set SKIP_BOOTSTRAP=1 to disable.
    Pins: BOOTSTRAP_TRANSFORMERS_VERSION, BOOTSTRAP_ACCELERATE_VERSION, BOOTSTRAP_TRL_VERSION.
    """
    if os.environ.get("SKIP_BOOTSTRAP") == "1":
        return
    # Ensure text-only transformers runs never hard-import torchvision even if it
    # is present in the base image.
    os.environ.setdefault("TRANSFORMERS_NO_TORCHVISION", "1")
    # Ubuntu 24.04+ images may mark system Python as "externally managed"
    # (PEP 668). Prefer an explicit opt-out for all pip ops in ephemeral jobs.
    os.environ.setdefault("PIP_BREAK_SYSTEM_PACKAGES", "1")
    print("Bootstrapping dependencies...")
    # Text-only run: torchvision/torchaudio are not required and are a common source
    # of crashes when torch versions shift in container images.
    _pip(["uninstall", "--break-system-packages", "-y", "torchvision", "torchaudio"], check=False)
    _pip(["uninstall", "-y", "torchao"], check=False)
    # trl 0.18.x needs transformers>=4.50; datasets 4.x pulls huggingface-hub 1.x,
    # which breaks transformers 4.5x, so datasets is capped below 4.0.
    _tf = os.environ.get("BOOTSTRAP_TRANSFORMERS_VERSION", "4.51.3")
    _acc = os.environ.get("BOOTSTRAP_ACCELERATE_VERSION", "0.34.2")
    _trl = os.environ.get("BOOTSTRAP_TRL_VERSION", "0.18.2")
    _pip(
        [
            "install",
            "--break-system-packages",
            "httpx>=0.27.0",
            "datasets>=3.2.0,<4.0.0",
            "matplotlib",
            "tensorboard",
            f"transformers=={_tf}",
            f"accelerate=={_acc}",
            f"trl=={_trl}",
        ]
    )
    if os.environ.get("WANDB_API_KEY"):
        _pip(["install", "--break-system-packages", "wandb"], check=False)
    # Re-pin transformers/accelerate with --no-deps in case the resolver moved them
    # while installing the packages above, then re-install trl against those pins.
    _pip(
        [
            "install",
            "--break-system-packages",
            "--force-reinstall",
            "--no-deps",
            f"transformers=={_tf}",
            f"accelerate=={_acc}",
        ]
    )
    _pip(["install", "--break-system-packages", "--no-deps", f"trl=={_trl}"])
    _pip(["uninstall", "-y", "torchao"], check=False)
    _pip(["uninstall", "--break-system-packages", "-y", "torchvision", "torchaudio"], check=False)
# Keep bootstrap import-free; training imports happen below.
bootstrap_deps()

import httpx
import torch
from datasets import Dataset
from huggingface_hub import HfApi
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer
# --- 1. CONFIGURATION (env-first; defaults match openenv.yaml) ---
_DEFAULT_OPENENV_BASE = "https://md896-sql-debug-env.hf.space"
BYPASS_HEADERS: Dict[str, str] = {}
MODEL_NAME = os.environ.get("TRAIN_MODEL_NAME", "Qwen/Qwen2.5-Coder-0.5B-Instruct")


def get_bridge_url() -> str:
    return os.environ.get("OPENENV_BASE_URL", _DEFAULT_OPENENV_BASE).rstrip("/")


def get_request_timeout() -> float:
    return float(os.environ.get("OPENENV_REQUEST_TIMEOUT_SEC", "120"))


def _fetch_task_ids(client: httpx.Client) -> List[str]:
    raw = os.environ.get("OPENENV_TASK_IDS", "").strip()
    if raw:
        return [x.strip() for x in raw.split(",") if x.strip()]
    r = client.get("/tasks", timeout=get_request_timeout())
    r.raise_for_status()
    body = r.json()
    tasks = body.get("tasks") or []
    ids = [t["task_id"] for t in tasks if t.get("task_id")]
    if not ids:
        raise RuntimeError("/tasks returned no task_id entries")
    return ids
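# /tasks response shape assumed by _fetch_task_ids (inferred from this client code,
# not from an official spec; "easy_join_fix" is a hypothetical task id):
#   {"tasks": [{"task_id": "easy_join_fix", ...}, ...]}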
# --- 2. DATASET (prompts built from live OpenEnv observations) ---
def make_real_dataset() -> Dataset:
    """Build plain prompts from live /tasks (same spirit as colab_real_world.py, but against the HF Space instead of a loca.lt tunnel)."""
    bridge = get_bridge_url()
    timeout = get_request_timeout()
    rows_per_task = max(1, int(os.environ.get("ROWS_PER_TASK", "48")))
    marker = os.environ.get("COMPLETION_SQL_MARKER", "Fixed SQL:")
    print(f"Connecting to OpenEnv at {bridge} (timeout={timeout}s)...")
    rows: List[Dict[str, Any]] = []
    with httpx.Client(base_url=bridge, headers=BYPASS_HEADERS, timeout=timeout) as client:
        h = client.get("/health", timeout=min(30.0, timeout))
        h.raise_for_status()
        print(f"OpenEnv health: {h.json()}")
        task_ids = _fetch_task_ids(client)
        print(f"Training task_ids ({len(task_ids)}): {task_ids}")
        for t_id in task_ids:
            resp = client.post("/reset", json={"task_id": t_id})
            resp.raise_for_status()
            obs = resp.json()["observation"]
            prompt = (
                "Fix the following SQL query and provide only the fixed SQL.\n"
                f"Task: {obs['task_description']}\n"
                f"Broken Query: {obs['original_query']}\n"
                f"{marker}"
            )
            for _ in range(rows_per_task):
                rows.append({"prompt": prompt, "task_id": t_id})
    if not rows:
        raise RuntimeError("Failed to build dataset (no rows).")
    print(f"Dataset: {len(rows)} prompts ({rows_per_task} per task).")
    return Dataset.from_list(rows)


def make_task_dataset(task_id: str, rows_per_task: int) -> Dataset:
    """Build rows_per_task identical prompt rows for a single task_id (used for the hard-task eval)."""
    bridge = get_bridge_url()
    timeout = get_request_timeout()
    marker = os.environ.get("COMPLETION_SQL_MARKER", "Fixed SQL:")
    with httpx.Client(base_url=bridge, headers=BYPASS_HEADERS, timeout=timeout) as client:
        resp = client.post("/reset", json={"task_id": task_id})
        resp.raise_for_status()
        obs = resp.json()["observation"]
    prompt = (
        "Fix the following SQL query and provide only the fixed SQL.\n"
        f"Task: {obs['task_description']}\n"
        f"Broken Query: {obs['original_query']}\n"
        f"{marker}"
    )
    rows = [{"prompt": prompt, "task_id": task_id} for _ in range(max(1, rows_per_task))]
    return Dataset.from_list(rows)
# --- 3. One live OpenEnv reward (colab_real_world style) ---
def openenv_sql_reward_func(completions, task_id, **kwargs):
    """Score completions by executing extracted SQL against the real OpenEnv HTTP API."""
    base = get_bridge_url()
    timeout = get_request_timeout()
    marker = os.environ.get("COMPLETION_SQL_MARKER", "Fixed SQL:")
    rewards: List[float] = []
    with httpx.Client(base_url=base, headers=BYPASS_HEADERS, timeout=timeout) as client:
        for completion, t_id in zip(completions, task_id):
            # Take only the text after the marker; fall back to the whole completion.
            if marker in completion:
                sql = completion.split(marker, 1)[-1].strip()
            else:
                sql = completion.strip()
            if not sql:
                rewards.append(0.0)
                continue
            # Fresh session per sample so concurrent rollouts never share env state.
            hdr = {"X-Session-Id": str(uuid.uuid4())}
            try:
                client.post("/reset", json={"task_id": t_id}, headers=hdr).raise_for_status()
                resp = client.post(
                    "/step",
                    json={"action": {"action_type": "submit_query", "query": sql}},
                    headers=hdr,
                )
                resp.raise_for_status()
                r = float(resp.json().get("reward", 0.0))
            except Exception:
                # Treat network/HTTP failures as zero reward rather than crashing training.
                r = 0.0
            # Tiny jitter breaks exact reward ties within a GRPO generation group.
            r += random.uniform(-1e-6, 1e-6)
            rewards.append(r)
    return rewards
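# HTTP contract assumed by the reward function above (shapes inferred from this
# client code, not from an official OpenEnv spec):
#   POST /reset {"task_id": "..."}  -> {"observation": {"task_description": ..., "original_query": ...}}
#   POST /step  {"action": {"action_type": "submit_query", "query": "..."}} -> {"reward": <float>, ...}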
def eval_model_reward(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    dataset: Dataset,
    *,
    max_items: int,
) -> float:
    """Sample up to max_items prompts, generate one completion each, and return the mean live reward."""
    subset = dataset.select(range(min(max_items, len(dataset))))
    prompts = subset["prompt"]
    task_ids = subset["task_id"]
    completions: List[str] = []
    for prompt in prompts:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=True,
                temperature=float(os.environ.get("EVAL_TEMPERATURE", "0.7")),
                top_p=float(os.environ.get("EVAL_TOP_P", "0.9")),
                renormalize_logits=True,
                remove_invalid_values=True,
                pad_token_id=tokenizer.eos_token_id,
            )
        completions.append(tokenizer.decode(out[0], skip_special_tokens=True))
    rewards = openenv_sql_reward_func(completions, task_ids)
    return float(sum(rewards) / max(len(rewards), 1))
# --- 3b. ARTIFACTS / PLOTS (REAL, FROM LOGS) ---
@dataclass
class ArtifactPaths:
    root: Path

    @property
    def logs_jsonl(self) -> Path:
        return self.root / "train_log_history.jsonl"

    @property
    def metrics_json(self) -> Path:
        return self.root / "train_metrics.json"

    @property
    def reward_curve_png(self) -> Path:
        return self.root / "reward_curve.png"


def _ensure_dir(path: Path) -> None:
    path.mkdir(parents=True, exist_ok=True)


def save_log_history(log_history: List[Dict[str, Any]], paths: ArtifactPaths) -> None:
    _ensure_dir(paths.root)
    with paths.logs_jsonl.open("w", encoding="utf-8") as f:
        for row in log_history:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
def extract_reward_series(log_history: List[Dict[str, Any]]) -> List[tuple[float, float]]:
    """
    Return [(step, reward_like_value)] pairs extracted from trainer log_history.

    TRL's log keys vary across versions; this scans a list of known candidates and
    falls back to any numeric key containing "reward".
    """
    candidates = [
        "reward",
        "rewards/mean",
        "rewards",
        "train/reward",
        "train/rewards",
        "objective/mean_reward",
        "mean_reward",
    ]
    series: List[tuple[float, float]] = []
    for row in log_history:
        # Explicit None checks so a legitimate step 0 is not skipped as falsy.
        step = next((row[k] for k in ("step", "global_step", "epoch") if row.get(k) is not None), None)
        if step is None:
            continue
        value = None
        for key in candidates:
            if key in row and isinstance(row[key], (int, float)):
                value = float(row[key])
                break
        if value is None:
            # Fallback: pick any numeric key containing "reward".
            for k, v in row.items():
                if "reward" in str(k).lower() and isinstance(v, (int, float)):
                    value = float(v)
                    break
        if value is None:
            continue
        series.append((float(step), value))
    # De-dup by step while preserving order.
    seen = set()
    deduped: List[tuple[float, float]] = []
    for s, v in series:
        if s in seen:
            continue
        seen.add(s)
        deduped.append((s, v))
    return deduped
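# Illustrative log_history row this extractor would match (exact keys depend on the
# TRL version; this shape is an assumption, not a guaranteed format):
#   {"step": 10, "loss": 0.0123, "reward": 0.5, "rewards/openenv_sql_reward_func/mean": 0.5}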
def write_metrics(log_history: List[Dict[str, Any]], reward_series: List[tuple[float, float]], paths: ArtifactPaths) -> None:
    metrics = {
        "generated_at_epoch_s": time.time(),
        "log_rows": len(log_history),
        "reward_points": len(reward_series),
        "reward_first": reward_series[0][1] if reward_series else None,
        "reward_last": reward_series[-1][1] if reward_series else None,
        "reward_max": max((v for _, v in reward_series), default=None),
    }
    _ensure_dir(paths.root)
    paths.metrics_json.write_text(json.dumps(metrics, indent=2), encoding="utf-8")


def plot_reward_curve(reward_series: List[tuple[float, float]], paths: ArtifactPaths) -> None:
    if not reward_series:
        print("⚠️ No reward series found in log history; skipping plot.")
        return
    import matplotlib.pyplot as plt

    xs = [s for s, _ in reward_series]
    ys = [v for _, v in reward_series]
    plt.figure(figsize=(9, 4))
    plt.plot(xs, ys, linewidth=2)
    plt.title("GRPO Reward Over Time (from run logs)")
    plt.xlabel("step")
    plt.ylabel("reward (extracted)")
    plt.grid(True, linestyle="--", alpha=0.4)
    _ensure_dir(paths.root)
    plt.tight_layout()
    plt.savefig(paths.reward_curve_png, dpi=200)
    print(f"Saved {paths.reward_curve_png}")
def _resolve_report_to() -> str:
    raw = os.environ.get("TRL_REPORT_TO", "").strip().lower()
    if raw in ("", "auto"):
        return "wandb" if os.environ.get("WANDB_API_KEY") else "none"
    if raw in ("false", "no", "off", "none"):
        return "none"
    return raw
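# Resolution examples: unset/"auto" -> "wandb" if WANDB_API_KEY is set, else "none";
# "false"/"no"/"off"/"none" -> "none"; anything else (e.g. "tensorboard") passes through.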
# --- 4. Simple GRPO training loop (live OpenEnv rewards) ---
def run_sota_train():
    max_steps = int(os.environ.get("TRAIN_MAX_STEPS", "200"))
    out_dir = os.environ.get("OUTPUT_DIR", "./sota_results")
    print(f"Starting GRPO on {MODEL_NAME}...")
    print(
        f"OpenEnv={get_bridge_url()} | max_steps={max_steps} | "
        f"rows_per_task={os.environ.get('ROWS_PER_TASK', '48')} | "
        f"report_to={_resolve_report_to()}"
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token
    use_cuda = torch.cuda.is_available()
    # L4/A10/A100 are typically more numerically stable with bf16 than fp16 for RL-style sampling.
    torch_dtype = torch.bfloat16 if use_cuda else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch_dtype,
        device_map="auto",
        attn_implementation=os.environ.get("ATTN_IMPLEMENTATION", "eager"),
    )
    # Runtime generation safety defaults (used by both eval and the GRPO generate path).
    model.generation_config.remove_invalid_values = True
    model.generation_config.renormalize_logits = True
    model.generation_config.top_p = float(os.environ.get("GRPO_TOP_P", "0.9"))
    model.generation_config.temperature = float(os.environ.get("GRPO_TEMPERATURE", "0.7"))
    train_dataset = make_real_dataset()
    print("Quick baseline eval (pre-train)...")
    baseline_avg_reward = eval_model_reward(model, tokenizer, train_dataset, max_items=8)
    hard_eval_n = int(os.environ.get("HARD_EVAL_SAMPLES", "16"))
    hard_dataset = make_task_dataset("hard_finance_explosion", rows_per_task=hard_eval_n)
    base_hard_reward = eval_model_reward(model, tokenizer, hard_dataset, max_items=hard_eval_n)
    report_to = _resolve_report_to()
    tb_dir = Path(out_dir) / "tensorboard"
    if report_to == "tensorboard":
        _ensure_dir(tb_dir)
    per_device_bs = int(os.environ.get("PER_DEVICE_TRAIN_BS", "1"))
    grad_accum = int(os.environ.get("GRAD_ACCUM", "2"))
    requested_num_gen = int(os.environ.get("GRPO_NUM_GENERATIONS", "8"))
    effective_bs = max(1, per_device_bs * grad_accum)
    if effective_bs % requested_num_gen != 0:
        # GRPO needs the effective batch to be divisible by num_generations;
        # fall back to the largest divisor (>= 2) of the effective batch.
        valid = [d for d in range(2, effective_bs + 1) if effective_bs % d == 0]
        num_gen = valid[-1] if valid else 2
        print(
            f"Adjusting GRPO_NUM_GENERATIONS from {requested_num_gen} to {num_gen} "
            f"for effective batch size {effective_bs}."
        )
    else:
        num_gen = requested_num_gen
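    # Worked example with the defaults above: per_device_bs=1 and grad_accum=2 give an
    # effective batch of 2; the requested 8 generations do not divide 2, so the
    # fallback picks the largest divisor of 2 that is >= 2 (just [2]) and num_gen=2.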
    _cfg: Dict[str, Any] = dict(
        output_dir=out_dir,
        learning_rate=float(os.environ.get("TRAIN_LR", "5e-6")),
        per_device_train_batch_size=per_device_bs,
        gradient_accumulation_steps=grad_accum,
        num_generations=num_gen,
        max_completion_length=int(os.environ.get("GRPO_MAX_COMPLETION_LEN", "256")),
        temperature=float(os.environ.get("GRPO_TEMPERATURE", "0.7")),
        top_p=float(os.environ.get("GRPO_TOP_P", "0.9")),
        bf16=bool(use_cuda),
        fp16=False,
        num_train_epochs=int(os.environ.get("TRAIN_NUM_EPOCHS", "1")),
        max_steps=max_steps,
        logging_steps=int(os.environ.get("LOGGING_STEPS", "1")),
        logging_first_step=True,
        report_to=report_to,
    )
    if report_to == "tensorboard":
        _cfg["logging_dir"] = str(tb_dir)
    training_args = GRPOConfig(**_cfg)
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=[openenv_sql_reward_func],
        args=training_args,
        train_dataset=train_dataset,
        processing_class=tokenizer,
    )
| print("Training with live execution rewards against OpenEnv...") | |
| trainer.train() | |
| print("Quick eval (post-train)...") | |
| post_avg_reward = eval_model_reward(model, tokenizer, train_dataset, max_items=8) | |
| trained_hard_reward = eval_model_reward(model, tokenizer, hard_dataset, max_items=hard_eval_n) | |
| # --- Save artifacts (real logs/plots) --- | |
| artifacts = ArtifactPaths(root=Path(out_dir) / "artifacts") | |
| log_history = getattr(trainer.state, "log_history", []) or [] | |
| save_log_history(log_history, artifacts) | |
| reward_series = extract_reward_series(log_history) | |
| write_metrics(log_history, reward_series, artifacts) | |
| # augment metrics with before/after | |
| metrics_path = artifacts.metrics_json | |
| try: | |
| metrics = json.loads(metrics_path.read_text(encoding="utf-8")) | |
| except Exception: | |
| metrics = {} | |
| metrics.update( | |
| { | |
| "openenv_base_url": get_bridge_url(), | |
| "train_max_steps": max_steps, | |
| "model_name": MODEL_NAME, | |
| "baseline_avg_reward": baseline_avg_reward, | |
| "post_avg_reward": post_avg_reward, | |
| "delta_avg_reward": post_avg_reward - baseline_avg_reward, | |
| "base_hard_reward": base_hard_reward, | |
| "trained_hard_reward": trained_hard_reward, | |
| "delta_hard_reward": trained_hard_reward - base_hard_reward, | |
| "tensorboard_dir": str(tb_dir) if report_to == "tensorboard" else None, | |
| "report_to": report_to, | |
| } | |
| ) | |
| metrics_path.write_text(json.dumps(metrics, indent=2), encoding="utf-8") | |
| plot_reward_curve(reward_series, artifacts) | |
    try:
        import matplotlib.pyplot as plt

        labels = ["baseline", "post-train"]
        values = [baseline_avg_reward, post_avg_reward]
        plt.figure(figsize=(5, 4))
        plt.bar(labels, values, color=["#94a3b8", "#22c55e"])
        plt.ylim(0, max(1.0, max(values) * 1.1))
        plt.title("Avg execution reward (sampled)")
        plt.ylabel("avg reward")
        out_path = artifacts.root / "before_after_avg_reward.png"
        plt.tight_layout()
        plt.savefig(out_path, dpi=200)
        print(f"Saved {out_path}")
    except Exception as e:
        print(f"Could not generate before/after plot: {e}")

    model_dir = os.environ.get("MODEL_SAVE_DIR", "./sota_sql_agent_full")
    print("\nSaving trained model locally...")
    model.save_pretrained(model_dir)
    tokenizer.save_pretrained(model_dir)  # keep the tokenizer alongside the weights
    hub_id = os.environ.get("MODEL_HUB_REPO_ID", os.environ.get("HF_HUB_REPO_ID", "md896/sql-debug-agent-qwen05b-grpo"))
    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    if os.environ.get("SKIP_HUB_PUSH", "").strip().lower() in ("1", "true", "yes"):
        print("SKIP_HUB_PUSH set; not pushing to Hub.")
    else:
        try:
            model.push_to_hub(hub_id, token=token)
            tokenizer.push_to_hub(hub_id, token=token)
            print(f"Pushed trained model to https://huggingface.co/{hub_id}")
        except Exception as e:
            print(f"Hub push failed (set HF_TOKEN / MODEL_HUB_REPO_ID or SKIP_HUB_PUSH=1): {e}")
    # Upload run artifacts back to the Space repo so you can download/view them.
    artifact_space = os.environ.get("ARTIFACT_SPACE_ID", "md896/sql-debug-env")
    run_tag = time.strftime("%Y%m%d-%H%M%S")
    try:
        if token:
            api = HfApi(token=token)
            api.upload_folder(
                repo_id=artifact_space,
                repo_type="space",
                folder_path=str(artifacts.root),
                path_in_repo=f"artifacts/runs/{run_tag}",
                commit_message=f"Add training artifacts {run_tag}",
            )
            print(f"Uploaded artifacts to https://huggingface.co/spaces/{artifact_space}/tree/main/artifacts/runs/{run_tag}")
        else:
            print("No HF token in job env; skipping artifact upload.")
    except Exception as e:
        print(f"Artifact upload failed: {e}")
    print(f"\nTraining artifacts under {artifacts.root}")


if __name__ == "__main__":
    run_sota_train()