code-gen-assistant / src /finetune /train_codet5.py
Rushabh147's picture
Initial deploy to HF Spaces (clean history, LFS for all binaries)
b89e6d6
Raw
History Blame Contribute Delete
2.66 kB
"""Phase 4: fine-tune CodeT5+ on docstring -> code.
This is the second experimental arm (a tuned small model) to compare against
frozen-LLM + RAG. Runs on a single mid-range GPU; raise subset/epochs for the
real result.
"""
from __future__ import annotations
import sys
from pathlib import Path
import pandas as pd
# Append the directory two levels above this file's folder to sys.path so the
# absolute ``src.config`` import on the next line can resolve. The import has
# to come after the path tweak, hence the E402 suppression.
sys.path.append(str(Path(__file__).resolve().parents[2]))
from src.config import load_config # noqa: E402
# Identifier passed to ``from_pretrained`` for both the tokenizer and the
# model that finetune() starts from.
CHECKPOINT = "Salesforce/codet5p-220m"
def finetune(subset_size: int = 5000, epochs: int = 1, out_dir: str = "data/codet5p-ft",
             cfg=None):
    """Fine-tune the CodeT5+ checkpoint on docstring -> code pairs.

    Reads ``train.parquet`` from ``cfg.paths.processed_dir``, keeps the first
    ``subset_size`` rows, tokenizes the ``docstring`` column as model input and
    the ``code`` column as the target, trains with ``Seq2SeqTrainer`` and then
    writes the model and tokenizer to ``out_dir``.

    Args:
        subset_size: Number of leading rows of the training parquet to use.
        epochs: Number of training epochs.
        out_dir: Directory used for trainer output and for the final saved
            model and tokenizer.
        cfg: Project configuration exposing ``paths.processed_dir``. When
            falsy, it is obtained from ``load_config()``.

    Returns:
        ``out_dir``, so the result can be handed straight to a loader.
    """
    # The ML dependencies are imported inside the function rather than at
    # module level.
    from datasets import Dataset
    from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer,
                              DataCollatorForSeq2Seq, Seq2SeqTrainer,
                              Seq2SeqTrainingArguments)
    import torch

    cfg = cfg or load_config()
    train_path = Path(cfg.paths.processed_dir) / "train.parquet"
    df = pd.read_parquet(train_path).head(subset_size)

    tok = AutoTokenizer.from_pretrained(CHECKPOINT)
    model = AutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT)

    def to_features(batch):
        # Truncate only; do NOT pad here. Padding the targets to a fixed
        # length at this stage would leave real pad-token ids in the labels,
        # and the loss would then be computed over those padded positions.
        # DataCollatorForSeq2Seq pads each batch at collate time instead and
        # fills label padding with -100, which the loss ignores.
        x = tok(batch["docstring"], max_length=64, truncation=True)
        y = tok(text_target=batch["code"], max_length=256, truncation=True)
        x["labels"] = y["input_ids"]
        return x

    ds = Dataset.from_pandas(df[["docstring", "code"]]).map(
        to_features, batched=True, remove_columns=["docstring", "code"])

    args = Seq2SeqTrainingArguments(
        output_dir=out_dir, per_device_train_batch_size=8, num_train_epochs=epochs,
        learning_rate=5e-5, logging_steps=50, save_strategy="epoch",
        # Mixed precision is only requested when a CUDA device is present.
        fp16=torch.cuda.is_available(), report_to="none")

    trainer = Seq2SeqTrainer(
        model=model, args=args, train_dataset=ds,
        # Pads inputs and labels per batch (labels with -100).
        data_collator=DataCollatorForSeq2Seq(tok, model=model))
    trainer.train()

    # Save the tokenizer next to the model so out_dir can be loaded on its own.
    trainer.save_model(out_dir)
    tok.save_pretrained(out_dir)
    print(f"[finetune] saved to {out_dir}")
    return out_dir
def make_t5_generate_fn(model_dir: str):
    """Build a ``generate_fn(intent) -> code`` callable from a saved model.

    The tokenizer and seq2seq model are loaded once from ``model_dir`` and the
    model is switched to eval mode; the returned closure reuses both on every
    call, so it can be plugged into evaluation code that expects a plain
    string-to-string function.
    """
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
    seq2seq.eval()

    def generate_fn(intent: str) -> str:
        # Prompt is truncated to 64 tokens; generation is capped at 256.
        encoded = tokenizer(intent, max_length=64, truncation=True,
                            return_tensors="pt")
        prompt_ids = encoded.input_ids.to(seq2seq.device)
        generated = seq2seq.generate(prompt_ids, max_length=256)
        first_sequence = generated[0]
        return tokenizer.decode(first_sequence, skip_special_tokens=True)

    return generate_fn