# tiny-transformer-trainer / train_app.py
# Leore42's picture
# Add training app
# 6df9039 verified
"""
Simple Transformer Training Environment
Train small GPT models from user-uploaded text data.
"""
import csv
import json
import os
import shutil
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple

import gradio as gr
import pandas as pd
from datasets import Dataset
from tokenizers import ByteLevelBPETokenizer
from transformers import (
    DataCollatorForLanguageModeling,
    GPT2Config,
    GPT2LMHeadModel,
    PreTrainedTokenizerFast,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)
# ---------------------------------------------------------------------------
# Constants & defaults
# ---------------------------------------------------------------------------
# Value pre-filled in the "Output directory" textbox of the UI.
DEFAULT_OUTPUT_DIR = "./trained_model_output"
# Choices offered by the architecture dropdowns in the UI. Every hidden size
# listed here is divisible by every head count listed here.
HIDDEN_SIZES = [128, 256, 384, 512]
LAYER_COUNTS = [2, 4, 6, 8, 12]
HEAD_COUNTS = [2, 4, 8]
MAX_SEQ_LENS = [128, 256, 512, 1024]
# Default template for turning a Q&A pair into one training sample; the
# loaders fill it with str.format(question=..., answer=...).
PROMPT_TEMPLATE_DEFAULT = "{question}\n{answer}"
# ---------------------------------------------------------------------------
# Dataset loading helpers
# ---------------------------------------------------------------------------
def load_text_from_txt(filepath: str) -> List[str]:
    """Read a UTF-8 text file and return its non-blank chunks.

    The text is cut on blank lines (double newlines) first. When that yields
    fewer than two chunks, it is cut on single newlines instead. Every chunk
    is stripped of surrounding whitespace and empty chunks are dropped.
    """

    def _non_blank_pieces(raw: str, separator: str) -> List[str]:
        stripped = (piece.strip() for piece in raw.split(separator))
        return [piece for piece in stripped if piece]

    raw_text = Path(filepath).read_text(encoding="utf-8")

    paragraphs = _non_blank_pieces(raw_text, "\n\n")
    if len(paragraphs) >= 2:
        return paragraphs
    # Too few paragraphs: fall back to one sample per line.
    return _non_blank_pieces(raw_text, "\n")
def load_qa_from_csv(
    filepath: str,
    question_col: str,
    answer_col: str,
    template: str = PROMPT_TEMPLATE_DEFAULT,
) -> List[str]:
    """Load Q&A pairs from a CSV file and format each row with *template*.

    Args:
        filepath: Path of the CSV file.
        question_col: Name of the column holding the question/prompt.
        answer_col: Name of the column holding the answer/response.
        template: ``str.format`` template filled with ``question`` and
            ``answer``.

    Returns:
        One formatted string per row that has a non-empty question or answer.

    Raises:
        ValueError: If either column is missing from the CSV header.
    """
    # Read every cell as literal text. With pandas' default parsing an empty
    # cell becomes NaN (and would end up as the word "nan" in the training
    # data) and numeric cells get reformatted (e.g. "3" -> "3.0").
    df = pd.read_csv(filepath, dtype=str, keep_default_na=False).fillna("")
    if question_col not in df.columns or answer_col not in df.columns:
        raise ValueError(
            f"CSV columns: {list(df.columns)} β€” "
            f"could not find '{question_col}' or '{answer_col}'"
        )
    # Rows with neither a question nor an answer carry no signal; skip them,
    # matching what the JSON loader does.
    return [
        template.format(question=str(q), answer=str(a))
        for q, a in zip(df[question_col], df[answer_col])
        if q or a
    ]
def load_qa_from_json(
    filepath: str,
    question_col: str,
    answer_col: str,
    template: str = PROMPT_TEMPLATE_DEFAULT,
) -> List[str]:
    """Load Q&A pairs from a JSON file and format each record with *template*.

    The file must hold either a top-level list of objects or an object whose
    ``"data"`` key holds that list. Non-object entries are ignored, as are
    records where both the question and the answer are missing or empty.

    Args:
        filepath: Path of the JSON file.
        question_col: Key holding the question/prompt in each record.
        answer_col: Key holding the answer/response in each record.
        template: ``str.format`` template filled with ``question`` and
            ``answer``.

    Returns:
        One formatted string per usable record.

    Raises:
        ValueError: If the payload is not a list (directly or under "data").
    """
    # "utf-8-sig" reads plain UTF-8 and also tolerates a leading BOM, which
    # would otherwise make json.load fail.
    with open(filepath, "r", encoding="utf-8-sig") as f:
        data = json.load(f)
    if isinstance(data, dict) and "data" in data:
        data = data["data"]
    if not isinstance(data, list):
        raise ValueError("JSON file must contain a top-level list or a dict with a 'data' key.")

    def _as_text(value) -> str:
        # JSON null means "missing"; str(None) would inject the word "None"
        # into the training data.
        return "" if value is None else str(value)

    texts = []
    for item in data:
        if not isinstance(item, dict):
            continue
        q = _as_text(item.get(question_col))
        a = _as_text(item.get(answer_col))
        if q or a:
            texts.append(template.format(question=q, answer=a))
    return texts
def detect_columns_csv(filepath: str) -> List[str]:
    """Return the column names of a CSV file, reading at most two data rows."""
    preview = pd.read_csv(filepath, nrows=2)
    column_names = list(preview.columns)
    return column_names
def detect_columns_json(filepath: str) -> List[str]:
    """Return the keys of the first record in a JSON dataset.

    The records are taken from the top-level list, or from the ``"data"`` key
    of a top-level object. An empty list is returned when there is no first
    record or when it is not an object.
    """
    with open(filepath, "r", encoding="utf-8") as handle:
        payload = json.load(handle)

    wrapped = isinstance(payload, dict) and "data" in payload
    records = payload["data"] if wrapped else payload

    if not isinstance(records, list) or not records:
        return []
    first_record = records[0]
    if not isinstance(first_record, dict):
        return []
    return list(first_record)
# ---------------------------------------------------------------------------
# Tokenizer training
# ---------------------------------------------------------------------------
def train_custom_tokenizer(texts: List[str], vocab_size: int, output_dir: str) -> PreTrainedTokenizerFast:
    """Train a byte-level BPE tokenizer on *texts* and save it to *output_dir*.

    The raw tokenizer is written to ``<output_dir>/tokenizer.json``, wrapped
    in a ``PreTrainedTokenizerFast`` with its special tokens declared, and the
    wrapper is saved to the same directory before being returned.
    """
    os.makedirs(output_dir, exist_ok=True)

    special_tokens = ["<s>", "<pad>", "</s>", "<unk>"]
    bpe = ByteLevelBPETokenizer(add_prefix_space=True)
    bpe.train_from_iterator(
        texts,
        vocab_size=vocab_size,
        min_frequency=2,
        special_tokens=special_tokens,
    )

    tokenizer_file = os.path.join(output_dir, "tokenizer.json")
    bpe.save(tokenizer_file)

    # Re-load through the transformers wrapper so the special-token roles are
    # stored alongside the vocabulary.
    special_token_roles = {
        "bos_token": "<s>",
        "eos_token": "</s>",
        "pad_token": "<pad>",
        "unk_token": "<unk>",
    }
    wrapped = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file, **special_token_roles)
    wrapped.save_pretrained(output_dir)
    return wrapped
# ---------------------------------------------------------------------------
# Model creation
# ---------------------------------------------------------------------------
def create_model(
    vocab_size: int,
    hidden_size: int,
    num_layers: int,
    num_heads: int,
    max_length: int,
) -> GPT2LMHeadModel:
    """Create a small, randomly initialised GPT-2 language model.

    Args:
        vocab_size: Number of entries in the token embedding table.
        hidden_size: Embedding width (``n_embd``); the feed-forward width is
            four times this value.
        num_layers: Number of transformer blocks.
        num_heads: Attention heads per block; must divide *hidden_size*.
        max_length: Maximum sequence length (``n_positions``).

    Returns:
        A ``GPT2LMHeadModel`` built from the resulting config.

    Raises:
        ValueError: If *hidden_size* is not divisible by *num_heads*.
    """
    if hidden_size % num_heads != 0:
        raise ValueError(
            f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
        )
    config = GPT2Config(
        vocab_size=vocab_size,
        n_positions=max_length,
        n_embd=hidden_size,
        n_layer=num_layers,
        n_head=num_heads,
        n_inner=hidden_size * 4,
        # These ids must follow the special-token order the tokenizer is
        # trained with: ["<s>", "<pad>", "</s>", "<unk>"] -> 0, 1, 2, 3.
        # The previous values (eos=1, pad=2) had EOS and PAD swapped.
        bos_token_id=0,
        pad_token_id=1,
        eos_token_id=2,
    )
    return GPT2LMHeadModel(config)
# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
def tokenize_dataset(dataset: Dataset, tokenizer: PreTrainedTokenizerFast, max_length: int):
    """Tokenize the ``"text"`` column of *dataset* in batches.

    Each sample is truncated and padded to *max_length*; the ``"text"``
    column is removed from the mapped dataset that is returned.
    """
    encode_options = {
        "truncation": True,
        "max_length": max_length,
        "padding": "max_length",
    }

    def _encode_batch(batch):
        return tokenizer(batch["text"], **encode_options)

    return dataset.map(_encode_batch, batched=True, remove_columns=["text"])
class TrainingStatus:
    """Mutable holder for the progress of a training run.

    The Trainer callback and the training entry point write to it; the UI
    reads the accumulated log through ``get_text``.
    """

    # Number of most recent log lines returned by get_text().
    _VISIBLE_LINES = 200

    def __init__(self):
        self.logs: List[str] = []
        self.step = 0
        self.total_steps = 0
        self.loss: Optional[float] = None
        self.done = False
        self.error: Optional[str] = None

    def append(self, msg: str):
        """Add one line to the log."""
        self.logs.append(msg)

    def get_text(self) -> str:
        """Return the most recent log lines joined with newlines."""
        recent = self.logs[-self._VISIBLE_LINES:]
        return "\n".join(recent)


# Module-wide status instance shared by the callback and the UI handler.
status = TrainingStatus()
class StatusCallback(TrainerCallback):
    """Trainer callback that mirrors training progress into ``status``.

    It subclasses ``TrainerCallback`` so that every Trainer event this class
    does not override resolves to the base class's no-op handler; a plain
    class would lack those event methods.
    """

    def __init__(self, total_steps: int):
        # Expected number of optimisation steps, used only for display.
        self.total_steps = total_steps

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Record the current step and, when present, the training loss."""
        if logs is None:
            return
        step = state.global_step
        status.step = step
        status.total_steps = self.total_steps
        loss = logs.get("loss")
        # Some log events carry no "loss" entry; only a numeric loss can be
        # formatted with ".4f".
        if isinstance(loss, (int, float)):
            status.loss = loss
            status.append(f"Step {step}/{self.total_steps} β€” loss={loss:.4f}")

    def on_train_end(self, args, state, control, **kwargs):
        """Mark the run as finished in the shared status."""
        status.append("βœ… Training complete!")
        status.done = True
# ---------------------------------------------------------------------------
# Main training orchestrator
# ---------------------------------------------------------------------------
def _resolve_upload_path(file_obj) -> str:
    """Return the filesystem path of an uploaded file.

    The upload widget is declared with ``type="filepath"``, so the handler
    receives a plain path string; objects exposing the path as ``.name`` are
    accepted as well.
    """
    if isinstance(file_obj, (str, os.PathLike)):
        return os.fspath(file_obj)
    return file_obj.name


def _load_training_texts(
    filepath: str,
    question_col: str,
    answer_col: str,
    prompt_template: str,
) -> List[str]:
    """Load training samples with the loader matching the file extension."""
    ext = Path(filepath).suffix.lower()
    if ext == ".txt":
        return load_text_from_txt(filepath)
    if ext == ".csv":
        return load_qa_from_csv(filepath, question_col, answer_col, prompt_template)
    if ext == ".json":
        return load_qa_from_json(filepath, question_col, answer_col, prompt_template)
    raise ValueError(f"Unsupported file extension: {ext}")


def _write_model_readme(
    output_dir: str,
    *,
    hidden_size: int,
    num_layers: int,
    num_heads: int,
    max_length: int,
    vocab_size: int,
    num_epochs: int,
    batch_size: int,
    learning_rate: float,
    num_samples: int,
) -> None:
    """Write ``README.md`` describing the trained model into *output_dir*."""
    readme_lines = [
        "# Trained Transformer Model",
        "## Architecture",
        "- **Type:** GPT-2 causal language model",
        f"- **Hidden size:** {hidden_size}",
        f"- **Layers:** {num_layers}",
        f"- **Attention heads:** {num_heads}",
        f"- **Max sequence length:** {max_length}",
        f"- **Vocab size:** {vocab_size}",
        "## Training",
        f"- **Epochs:** {num_epochs}",
        f"- **Batch size:** {batch_size}",
        f"- **Learning rate:** {learning_rate}",
        f"- **Samples:** {num_samples}",
        "## Files",
        "- `model/` β€” model weights + config",
        "- `tokenizer/` β€” tokenizer vocab + config",
        "- `tokenizer/tokenizer.json` β€” raw tokenizer file",
        "## Usage",
        "```python",
        "from transformers import GPT2LMHeadModel, PreTrainedTokenizerFast",
        f'model = GPT2LMHeadModel.from_pretrained("{output_dir}/model")',
        f'tokenizer = PreTrainedTokenizerFast.from_pretrained("{output_dir}/tokenizer")',
        "```",
    ]
    readme_path = os.path.join(output_dir, "README.md")
    with open(readme_path, "w", encoding="utf-8") as f:
        f.write("\n".join(readme_lines) + "\n")


def run_training(
    file_obj,
    file_type: str,
    question_col: str,
    answer_col: str,
    prompt_template: str,
    vocab_size: int,
    hidden_size: int,
    num_layers: int,
    num_heads: int,
    max_length: int,
    num_epochs: int,
    batch_size: int,
    learning_rate: float,
    output_dir: str,
    progress=gr.Progress(),
):
    """Main training entry point used by Gradio (a generator).

    Loads the uploaded dataset, trains a tokenizer, builds and trains a small
    GPT-2 model, then saves the model, tokenizer, a README and a zip archive.

    Yields:
        ``(log_text, zip_path)`` tuples. ``zip_path`` is ``None`` until the
        final yield of a successful run. Any failure is reported as a final
        ``"Error: ..."`` log line instead of escaping the generator.

    Notes:
        ``file_type`` and ``progress`` are accepted for interface
        compatibility but are not used; the loader is chosen from the file
        extension.
    """
    # Function-scope imports keep this block independent of the file header.
    import inspect
    import threading

    global status
    status = TrainingStatus()

    try:
        # --- 1. Load data ---
        status.append("πŸ“‚ Loading data…")
        yield status.get_text(), None

        if file_obj is None:
            raise ValueError("No file uploaded.")

        # UI widgets may deliver numbers as floats; the Trainer needs ints.
        vocab_size = int(vocab_size)
        hidden_size = int(hidden_size)
        num_layers = int(num_layers)
        num_heads = int(num_heads)
        max_length = int(max_length)
        num_epochs = max(1, int(num_epochs))
        batch_size = max(1, int(batch_size))
        learning_rate = float(learning_rate)
        # A blank textbox would otherwise scatter files into the working dir.
        output_dir = (output_dir or "").strip() or DEFAULT_OUTPUT_DIR

        filepath = _resolve_upload_path(file_obj)
        texts = _load_training_texts(filepath, question_col, answer_col, prompt_template)
        if not texts:
            raise ValueError("No valid text samples found in file.")
        status.append(f"βœ… Loaded {len(texts)} text samples.")
        yield status.get_text(), None

        # --- 2. Train tokenizer ---
        status.append("πŸ”€ Training tokenizer…")
        yield status.get_text(), None
        tokenizer_output = os.path.join(output_dir, "tokenizer")
        os.makedirs(tokenizer_output, exist_ok=True)
        tokenizer = train_custom_tokenizer(texts, vocab_size, tokenizer_output)
        status.append(f"βœ… Tokenizer saved to {tokenizer_output}")
        yield status.get_text(), None

        # --- 3. Create model ---
        status.append("πŸ—οΈ Creating model…")
        yield status.get_text(), None
        # len(tokenizer) counts every id the tokenizer can emit (including
        # added tokens), which is what the embedding table must cover. A
        # small corpus can yield fewer entries than the requested vocab_size.
        actual_vocab_size = len(tokenizer)
        model = create_model(
            vocab_size=actual_vocab_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            num_heads=num_heads,
            max_length=max_length,
        )
        status.append(f"βœ… Model created: {num_layers} layers, {hidden_size} hidden, {num_heads} heads")
        yield status.get_text(), None

        # --- 4. Prepare dataset ---
        status.append("πŸ“Š Preparing dataset…")
        yield status.get_text(), None
        dataset = Dataset.from_dict({"text": texts})
        tokenized = tokenize_dataset(dataset, tokenizer, max_length)
        status.append(f"βœ… Dataset tokenized: {len(tokenized)} samples")
        yield status.get_text(), None

        # --- 5. Train ---
        status.append(f"πŸš€ Starting training ({num_epochs} epochs, lr={learning_rate})…")
        yield status.get_text(), None
        os.makedirs(output_dir, exist_ok=True)
        # Ceiling division: the last, partially filled batch is still a step.
        # This is a single-device estimate used for display and scheduling.
        steps_per_epoch = max(1, -(-len(tokenized) // batch_size))
        total_steps = steps_per_epoch * num_epochs
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            learning_rate=learning_rate,
            weight_decay=0.01,
            logging_strategy="steps",
            logging_steps=max(1, total_steps // 20),
            save_strategy="epoch",
            save_total_limit=2,
            warmup_steps=max(1, total_steps // 10),
            fp16=False,
            dataloader_num_workers=0,
            disable_tqdm=True,
            logging_first_step=True,
        )
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

        # Pass the tokenizer under whichever keyword this Trainer version
        # accepts ('processing_class' or 'tokenizer').
        trainer_kwargs = {
            "model": model,
            "args": training_args,
            "train_dataset": tokenized,
            "data_collator": data_collator,
            "callbacks": [StatusCallback(total_steps)],
        }
        trainer_params = inspect.signature(Trainer.__init__).parameters
        if "processing_class" in trainer_params:
            trainer_kwargs["processing_class"] = tokenizer
        elif "tokenizer" in trainer_params:
            trainer_kwargs["tokenizer"] = tokenizer
        trainer = Trainer(**trainer_kwargs)

        # trainer.train() blocks until training ends. Run it in a worker
        # thread and poll the shared status so the callback's per-step log
        # lines reach the UI while training is still in progress.
        worker_errors: List[BaseException] = []

        def _train_worker():
            try:
                trainer.train()
            except Exception as exc:  # re-raised in the generator below
                worker_errors.append(exc)

        # NOTE: if the client disconnects mid-run the generator is closed,
        # but the daemon worker keeps training in the background.
        worker = threading.Thread(target=_train_worker, daemon=True)
        worker.start()
        last_text = None
        while worker.is_alive():
            worker.join(timeout=1.0)
            current_text = status.get_text()
            if current_text != last_text:
                last_text = current_text
                yield current_text, None
        if worker_errors:
            raise worker_errors[0]

        # --- 6. Save everything ---
        status.append("πŸ’Ύ Saving model & tokenizer…")
        yield status.get_text(), None
        model.save_pretrained(os.path.join(output_dir, "model"))
        tokenizer.save_pretrained(tokenizer_output)
        _write_model_readme(
            output_dir,
            hidden_size=hidden_size,
            num_layers=num_layers,
            num_heads=num_heads,
            max_length=max_length,
            vocab_size=actual_vocab_size,
            num_epochs=num_epochs,
            batch_size=batch_size,
            learning_rate=learning_rate,
            num_samples=len(texts),
        )

        # Package as a zip for easy download. normpath drops a trailing
        # slash so the archive is created next to the output directory
        # rather than inside the directory being archived.
        zip_path = shutil.make_archive(os.path.normpath(output_dir), "zip", output_dir)
        status.append(f"βœ… All done! Model saved to `{output_dir}`")
        status.append(f"πŸ“¦ Download zip: `{zip_path}`")
        status.done = True
        yield status.get_text(), zip_path
    except Exception as exc:
        # UI boundary: show the failure in the log box rather than letting
        # the exception escape the generator.
        status.error = str(exc) or type(exc).__name__
        status.append(f"❌ Error: {status.error}")
        yield status.get_text(), None
# ---------------------------------------------------------------------------
# Gradio UI helpers
# ---------------------------------------------------------------------------
def update_ui_visibility(file_type: str):
    """Show/hide the Q&A inputs and refresh the upload hint for *file_type*.

    Returns:
        Four ``gr.update`` objects, in order: question column box, answer
        column box, prompt template box, file upload widget.
    """
    is_plain_text = file_type == "Plain text (.txt)"
    if is_plain_text:
        hint = "Upload a .txt file with raw text"
    else:
        # The extension is the text between the parentheses, e.g. ".csv" in
        # "CSV Q&A pairs (.csv)". partition() tolerates a value without "(".
        _, _, tail = str(file_type or "").partition("(")
        extension = tail.replace(")", "").strip()
        hint = f"Upload a {extension} file" if extension else "Upload a file"

    show_qa_inputs = not is_plain_text
    return [
        gr.update(visible=show_qa_inputs),  # question_col
        gr.update(visible=show_qa_inputs),  # answer_col
        gr.update(visible=show_qa_inputs),  # prompt_template
        # The file widget has no placeholder property, so the hint is shown
        # through its label instead.
        gr.update(label=hint),
    ]
def auto_detect_cols(file_obj, file_type: str):
    """Suggest question/answer column names for an uploaded CSV/JSON file.

    Args:
        file_obj: The upload: a path string (what ``type="filepath"``
            delivers), a path-like object, or an object exposing ``.name``.
        file_type: The selected dataset type label.

    Returns:
        ``(question_col, answer_col)``; ``("question", "answer")`` when
        nothing can be detected.
    """
    defaults = ("question", "answer")
    if file_obj is None or file_type == "Plain text (.txt)":
        return defaults

    if isinstance(file_obj, (str, os.PathLike)):
        filepath = os.fspath(file_obj)
    else:
        filepath = getattr(file_obj, "name", None)
    if not filepath:
        return defaults

    ext = Path(filepath).suffix.lower()
    try:
        if ext == ".csv":
            cols = detect_columns_csv(filepath)
        elif ext == ".json":
            cols = detect_columns_json(filepath)
        else:
            return defaults
    except Exception:
        # Detection is a best-effort convenience: an unreadable or malformed
        # file simply yields the default names.
        return defaults

    names = [str(col) for col in cols]
    if not names:
        return defaults

    def _first_match(exact: str, keywords):
        """Return the first name equal to *exact* or containing a keyword."""
        for name in names:
            lowered = name.lower()
            if lowered == exact or any(word in lowered for word in keywords):
                return name
        return None

    q_col = _first_match("q", ("question", "prompt"))
    if q_col is None:
        q_col = names[0]

    a_col = _first_match("a", ("answer", "response", "output"))
    if a_col is None:
        # Prefer a column other than the question column; only a
        # single-column file ends up with the same name twice.
        a_col = next((name for name in names if name != q_col), names[0])
    return q_col, a_col
# ---------------------------------------------------------------------------
# Gradio App
# ---------------------------------------------------------------------------
with gr.Blocks(title="🧠 Tiny Transformer Trainer") as demo:
    gr.Markdown("""
# 🧠 Tiny Transformer Trainer
Upload your text data and train a small GPT model from scratch.
Supports `.txt` (plain text), `.csv` and `.json` (Q&A pairs).
""")

    with gr.Row():
        # Left column: dataset, architecture and training controls.
        with gr.Column(scale=1):
            gr.Markdown("### πŸ“€ Data Upload")
            dataset_type_choices = [
                "Plain text (.txt)",
                "CSV Q&A pairs (.csv)",
                "JSON Q&A pairs (.json)",
            ]
            file_type = gr.Dropdown(
                choices=dataset_type_choices,
                value="Plain text (.txt)",
                label="Dataset type",
            )
            file_input = gr.File(label="Upload file", type="filepath")
            # The Q&A-specific inputs start hidden because plain text is the
            # default dataset type.
            question_col = gr.Textbox(
                value="question",
                label="Question/prompt column name",
                visible=False,
            )
            answer_col = gr.Textbox(
                value="answer",
                label="Answer/response column name",
                visible=False,
            )
            prompt_template = gr.Textbox(
                value="{question}\n{answer}",
                label="Prompt template (use {question} and {answer})",
                visible=False,
            )
            auto_detect_btn = gr.Button("πŸ” Auto-detect columns", visible=False)

            gr.Markdown("---")
            gr.Markdown("### πŸ—οΈ Model Architecture")
            vocab_size = gr.Slider(1000, 32768, value=10000, step=1000, label="Vocabulary size")
            hidden_size = gr.Dropdown(
                choices=HIDDEN_SIZES,
                value=256,
                label="Hidden size (embedding dim)",
            )
            num_layers = gr.Dropdown(choices=LAYER_COUNTS, value=4, label="Number of layers")
            num_heads = gr.Dropdown(choices=HEAD_COUNTS, value=4, label="Attention heads")
            max_length = gr.Dropdown(choices=MAX_SEQ_LENS, value=256, label="Max sequence length")

            gr.Markdown("---")
            gr.Markdown("### βš™οΈ Training Settings")
            num_epochs = gr.Slider(1, 20, value=3, step=1, label="Epochs")
            batch_size = gr.Slider(1, 32, value=8, step=1, label="Batch size")
            learning_rate = gr.Number(value=5e-4, label="Learning rate")
            output_dir = gr.Textbox(value=DEFAULT_OUTPUT_DIR, label="Output directory")
            train_btn = gr.Button("πŸš€ Start Training", variant="primary")

        # Right column: live log and the downloadable archive.
        with gr.Column(scale=1):
            gr.Markdown("### πŸ“‹ Training Log")
            log_box = gr.Textbox(label="", lines=25, interactive=False, show_label=False)
            zip_download = gr.File(label="πŸ“¦ Download trained model (zip)", visible=True)

    # -------------------------------------------------------------------
    # Event wiring
    # -------------------------------------------------------------------
    def on_file_type_change(ft):
        """Forward the selected dataset type to the visibility helper."""
        return update_ui_visibility(ft)

    def on_auto_detect(file_obj, ft):
        """Fill the column boxes with the detected question/answer names."""
        detected_question, detected_answer = auto_detect_cols(file_obj, ft)
        return detected_question, detected_answer

    # Changing the dataset type toggles the Q&A inputs and the upload hint...
    file_type.change(
        on_file_type_change,
        inputs=[file_type],
        outputs=[question_col, answer_col, prompt_template, file_input],
    )
    # ...and, through a second listener, the auto-detect button.
    file_type.change(
        lambda ft: gr.update(visible=(ft != "Plain text (.txt)")),
        inputs=[file_type],
        outputs=[auto_detect_btn],
    )

    auto_detect_btn.click(
        on_auto_detect,
        inputs=[file_input, file_type],
        outputs=[question_col, answer_col],
    )

    # Order must match the positional parameters of run_training.
    training_inputs = [
        file_input,
        file_type,
        question_col,
        answer_col,
        prompt_template,
        vocab_size,
        hidden_size,
        num_layers,
        num_heads,
        max_length,
        num_epochs,
        batch_size,
        learning_rate,
        output_dir,
    ]
    train_btn.click(
        run_training,
        inputs=training_inputs,
        outputs=[log_box, zip_download],
    )
if __name__ == "__main__":
    # Listen on all interfaces on port 7860, without a public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "theme": gr.themes.Soft(),
    }
    demo.launch(**launch_options)