modelforge-backend / agents /services /code_generator.py
ModelForge CI
deploy: 2026-06-19 19:24 UTC
6761f70
Raw
History Blame Contribute Delete
10.3 kB
"""
Standalone training script generator.
Takes pipeline artifacts (TaskSpec + DataProfile + ModelRecipe + TrainingResult)
and emits a self-contained, copy-paste-runnable Python script.
The script is syntax-checked with ast.parse() before it is returned, so callers never receive a script that fails to parse (runtime behaviour is not verified).
"""
from __future__ import annotations
import ast
import json
from typing import Any
class CodeGenerationError(Exception):
    """Signal that an emitted training script could not be parsed by ``ast.parse``."""
def generate_training_script(
task_spec: dict[str, Any],
data_profile: dict[str, Any],
model_recipe: dict[str, Any],
training_result: dict[str, Any],
) -> str:
"""
Generate a standalone Python training script from pipeline artifacts.
Returns:
str β€” the script source code.
Raises:
CodeGenerationError β€” if the output fails ast.parse().
"""
approach = str(model_recipe.get("training_approach", "full_finetune"))
base_model = str(model_recipe.get("base_model", "distilbert-base-uncased"))
# Resolve label names (prefer training_result β†’ task_spec β†’ data_profile)
label_names = (
training_result.get("label_names")
or task_spec.get("label_names")
or list(data_profile.get("label_distribution", {}).keys())
or []
)
label_names = [str(l) for l in label_names]
is_lora = approach in ("lora", "qlora")
is_qlora = approach == "qlora"
lr = float(model_recipe.get("learning_rate", 2e-5) or 2e-5)
epochs = int(model_recipe.get("num_epochs", 3) or 3)
batch = int(model_recipe.get("batch_size", 16) or 16)
max_len = int(model_recipe.get("max_length", 128) or 128)
warmup = float(model_recipe.get("warmup_ratio", 0.1) or 0.1)
wd = float(model_recipe.get("weight_decay", 0.01) or 0.01)
lora_r = int(model_recipe.get("lora_r", 16) or 16)
lora_alpha = int(model_recipe.get("lora_alpha") or lora_r * 2)
input_col = str(task_spec.get("input_column", "text"))
label_col = str(task_spec.get("label_column", "label"))
acc = training_result.get("accuracy")
f1 = training_result.get("f1")
metrics_str = (
f"Accuracy={acc*100:.1f}% F1={f1:.3f}"
if acc is not None and f1 is not None
else "metrics unavailable"
)
# ── Build sections ─────────────────────────────────────────────────────────
extra_imports = []
if is_qlora:
extra_imports.append("from transformers import BitsAndBytesConfig")
extra_imports_str = "\n".join(extra_imports)
# CONFIG dict lines
config_lines = [
f' "base_model": {json.dumps(base_model)},',
f' "training_approach": {json.dumps(approach)},',
f' "label_names": {json.dumps(label_names)},',
f' "input_column": {json.dumps(input_col)},',
f' "label_column": {json.dumps(label_col)},',
f' "learning_rate": {lr!r},',
f' "num_epochs": {epochs},',
f' "batch_size": {batch},',
f' "max_length": {max_len},',
f' "warmup_ratio": {warmup},',
f' "weight_decay": {wd},',
]
config_block = "\n".join(config_lines)
# QLoRA config block (inside train())
if is_qlora:
bnb_block = (
" bnb_config = BitsAndBytesConfig(\n"
" load_in_4bit=True,\n"
' bnb_4bit_quant_type="nf4",\n'
" bnb_4bit_use_double_quant=True,\n"
" bnb_4bit_compute_dtype=torch.bfloat16,\n"
" )"
)
model_load_block = (
' model = AutoModelForSequenceClassification.from_pretrained(\n'
' CONFIG["base_model"],\n'
' num_labels=len(CONFIG["label_names"]),\n'
' quantization_config=bnb_config,\n'
' device_map="auto",\n'
' )'
)
else:
bnb_block = ""
model_load_block = (
' model = AutoModelForSequenceClassification.from_pretrained(\n'
' CONFIG["base_model"],\n'
' num_labels=len(CONFIG["label_names"]),\n'
' )'
)
# LoRA adapter block (inside train(), after model load)
if is_lora:
lora_block = (
" # ── LoRA adapter ─────────────────────────────────────────────────\n"
" from peft import get_peft_model, LoraConfig, TaskType\n"
" lora_cfg = LoraConfig(\n"
" task_type=TaskType.SEQ_CLS,\n"
f" r={lora_r},\n"
f" lora_alpha={lora_alpha},\n"
' target_modules=["query", "key", "value"],\n'
' bias="none",\n'
" )\n"
" model = get_peft_model(model, lora_cfg)\n"
" model.print_trainable_parameters()"
)
else:
lora_block = ""
# Build the full script using explicit string concatenation
sections = []
sections.append(f'''\
#!/usr/bin/env python3
"""
Auto-generated training script by ModelForge.
Run result: {metrics_str}
Base model: {base_model} Approach: {approach}
Usage:
pip install transformers datasets peft accelerate scikit-learn
python train.py --data_path path/to/your/data.csv
"""
import argparse
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from datasets import Dataset
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
TrainingArguments,
Trainer,
EarlyStoppingCallback,
DataCollatorWithPadding,
)''')
if extra_imports_str:
sections.append(extra_imports_str)
sections.append(f'''\
# ── Configuration ─────────────────────────────────────────────────────────────
CONFIG = {{
{config_block}
}}''')
sections.append('''\
def load_data(data_path: str):
df = pd.read_csv(data_path)
assert CONFIG["input_column"] in df.columns, f"Missing column: {CONFIG['input_column']}"
assert CONFIG["label_column"] in df.columns, f"Missing column: {CONFIG['label_column']}"
df = df.dropna(subset=[CONFIG["input_column"], CONFIG["label_column"]])
df[CONFIG["input_column"]] = df[CONFIG["input_column"]].astype(str)
df[CONFIG["label_column"]] = df[CONFIG["label_column"]].astype(str)
label_names = CONFIG["label_names"] or sorted(df[CONFIG["label_column"]].unique().tolist())
label2id = {lbl: i for i, lbl in enumerate(label_names)}
df["label_int"] = df[CONFIG["label_column"]].map(label2id)
df = df.dropna(subset=["label_int"])
df["label_int"] = df["label_int"].astype(int)
train_df, test_df = train_test_split(
df, test_size=0.2, random_state=42,
stratify=df["label_int"] if df["label_int"].nunique() > 1 else None,
)
return train_df, test_df, label_names
def tokenize(df, tokenizer):
hf_ds = Dataset.from_dict({
"text": df[CONFIG["input_column"]].tolist(),
"label": df["label_int"].tolist(),
})
return hf_ds.map(
lambda b: tokenizer(
b["text"],
truncation=True,
max_length=CONFIG["max_length"],
padding=False,
),
batched=True,
)
def compute_metrics(p):
preds = p.predictions.argmax(axis=-1)
refs = p.label_ids
return {
"accuracy": accuracy_score(refs, preds),
"f1": f1_score(refs, preds, average="weighted", zero_division=0),
}
def train(data_path: str, output_dir: str = "trained_model"):
tokenizer = AutoTokenizer.from_pretrained(CONFIG["base_model"])
train_df, test_df, label_names = load_data(data_path)
train_ds = tokenize(train_df, tokenizer)
test_ds = tokenize(test_df, tokenizer)
''')
if bnb_block:
sections.append(bnb_block)
sections.append(model_load_block)
if lora_block:
sections.append("")
sections.append(lora_block)
sections.append('''\
collator = DataCollatorWithPadding(tokenizer=tokenizer)
args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=CONFIG["num_epochs"],
per_device_train_batch_size=CONFIG["batch_size"],
per_device_eval_batch_size=CONFIG["batch_size"] * 2,
learning_rate=CONFIG["learning_rate"],
warmup_ratio=CONFIG["warmup_ratio"],
weight_decay=CONFIG["weight_decay"],
eval_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="f1",
fp16=torch.cuda.is_available(),
report_to="none",
)
trainer = Trainer(
model=model,
args=args,
train_dataset=train_ds,
eval_dataset=test_ds,
tokenizer=tokenizer,
data_collator=collator,
compute_metrics=compute_metrics,
callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)
trainer.train()
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)
print(f"\\nModel saved to {output_dir}")
results = trainer.evaluate()
print(f"Eval accuracy: {results.get('eval_accuracy', 0)*100:.1f}%")
print(f"Eval F1: {results.get('eval_f1', 0):.3f}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data_path", required=True, help="Path to your CSV dataset")
parser.add_argument("--output_dir", default="trained_model")
args = parser.parse_args()
train(data_path=args.data_path, output_dir=args.output_dir)
''')
script = "\n".join(sections)
# Validate syntax before returning
try:
ast.parse(script)
except SyntaxError as exc:
raise CodeGenerationError(
f"Generated script has a syntax error at line {exc.lineno}: {exc.msg}"
) from exc
return script