# NOTE: scraped page header ("Spaces: Sleeping") -- not part of the program.
"""
Simple Transformer Training Environment

Train small GPT models from user-uploaded text data.
"""
import csv
import inspect
import json
import os
import re
import shutil
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple

import gradio as gr
import pandas as pd
from datasets import Dataset
from tokenizers import ByteLevelBPETokenizer
from transformers import (
    DataCollatorForLanguageModeling,
    GPT2Config,
    GPT2LMHeadModel,
    PreTrainedTokenizerFast,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)
# ---------------------------------------------------------------------------
# Constants & defaults
# ---------------------------------------------------------------------------
# Initial value of the "Output directory" textbox in the UI.
DEFAULT_OUTPUT_DIR = "./trained_model_output"
# Choices offered by the architecture dropdowns in the UI.
HIDDEN_SIZES = [128, 256, 384, 512]
LAYER_COUNTS = [2, 4, 6, 8, 12]
HEAD_COUNTS = [2, 4, 8]
MAX_SEQ_LENS = [128, 256, 512, 1024]
# Default ``str.format`` template for the CSV/JSON Q&A loaders; it is filled
# with the keyword arguments ``question`` and ``answer``.
PROMPT_TEMPLATE_DEFAULT = "{question}\n{answer}"
# ---------------------------------------------------------------------------
# Dataset loading helpers
# ---------------------------------------------------------------------------
def load_text_from_txt(filepath: str) -> List[str]:
    """Load a plain-text file and split it into training samples.

    The text is split into paragraphs on blank lines (lines that are empty or
    contain only whitespace). If that yields fewer than two paragraphs, the
    file is split on single newlines instead.

    Args:
        filepath: Path of the ``.txt`` file to read (UTF-8, BOM tolerated).

    Returns:
        The non-empty, whitespace-stripped chunks in file order; an empty
        list when the file contains no text.
    """
    # "utf-8-sig" strips a leading byte-order mark (common in files saved by
    # Windows editors) so it does not end up inside the first sample.
    with open(filepath, "r", encoding="utf-8-sig") as f:
        text = f.read()

    # A "blank" line may still contain spaces or tabs, so match "\n\s*\n"
    # rather than the literal "\n\n".
    chunks = [part.strip() for part in re.split(r"\n\s*\n", text) if part.strip()]
    if len(chunks) < 2:
        # Too few paragraphs: fall back to one sample per line.
        chunks = [line.strip() for line in text.split("\n") if line.strip()]
    return chunks
def load_qa_from_csv(
    filepath: str,
    question_col: str,
    answer_col: str,
    template: str = PROMPT_TEMPLATE_DEFAULT,
) -> List[str]:
    """Load Q&A pairs from a CSV file and format each row with ``template``.

    Args:
        filepath: Path of the CSV file.
        question_col: Name of the column holding the question/prompt.
        answer_col: Name of the column holding the answer/response.
        template: ``str.format`` template filled with the keyword arguments
            ``question`` and ``answer``.

    Returns:
        One formatted string per row that has a non-empty question or answer.

    Raises:
        ValueError: If either column is missing from the CSV header.
    """
    # Read every cell as text and keep empty cells as "" so a missing value
    # does not become the literal string "nan" in the training data.
    df = pd.read_csv(filepath, dtype=str, keep_default_na=False)
    if question_col not in df.columns or answer_col not in df.columns:
        raise ValueError(
            f"CSV columns: {list(df.columns)} β "
            f"could not find '{question_col}' or '{answer_col}'"
        )

    texts = []
    for q, a in zip(df[question_col], df[answer_col]):
        # Skip rows with neither a question nor an answer, matching the
        # behaviour of the JSON loader.
        if q or a:
            texts.append(template.format(question=q, answer=a))
    return texts
def load_qa_from_json(
    filepath: str,
    question_col: str,
    answer_col: str,
    template: str = PROMPT_TEMPLATE_DEFAULT,
) -> List[str]:
    """Load Q&A pairs from a JSON file and format each record with ``template``.

    The file must contain either a top-level list of objects or an object
    whose ``"data"`` key holds such a list. Items that are not objects are
    skipped, as are records whose question and answer are both empty.

    Args:
        filepath: Path of the JSON file (UTF-8, BOM tolerated).
        question_col: Key holding the question/prompt in each record.
        answer_col: Key holding the answer/response in each record.
        template: ``str.format`` template filled with the keyword arguments
            ``question`` and ``answer``.

    Returns:
        The formatted strings in file order.

    Raises:
        ValueError: If the JSON document is not a list (directly or under
            ``"data"``).
    """
    # json.load rejects text that starts with a BOM; "utf-8-sig" strips it.
    with open(filepath, "r", encoding="utf-8-sig") as f:
        data = json.load(f)
    if isinstance(data, dict) and "data" in data:
        data = data["data"]
    if not isinstance(data, list):
        raise ValueError("JSON file must contain a top-level list or a dict with a 'data' key.")

    texts = []
    for item in data:
        if not isinstance(item, dict):
            continue
        q_raw = item.get(question_col)
        a_raw = item.get(answer_col)
        # Treat a missing key and an explicit JSON null alike, so null does
        # not become the literal text "None".
        q = "" if q_raw is None else str(q_raw)
        a = "" if a_raw is None else str(a_raw)
        if q or a:
            texts.append(template.format(question=q, answer=a))
    return texts
def detect_columns_csv(filepath: str) -> List[str]:
    """Return the column names found in the header row of a CSV file.

    Args:
        filepath: Path of the CSV file.

    Returns:
        The column labels in file order.
    """
    # nrows=0 parses only the header; no data rows are needed for the names.
    header_only = pd.read_csv(filepath, nrows=0)
    return list(header_only.columns)
def detect_columns_json(filepath: str) -> List[str]:
    """Return the keys of the first record in a JSON Q&A file.

    Accepts the same layouts as the JSON loader: a top-level list, or an
    object whose ``"data"`` key holds the list.

    Args:
        filepath: Path of the JSON file (UTF-8, BOM tolerated).

    Returns:
        The keys of the first list item that is an object, in document
        order; an empty list when there is no such item.
    """
    # json.load rejects text that starts with a BOM; "utf-8-sig" strips it.
    with open(filepath, "r", encoding="utf-8-sig") as f:
        data = json.load(f)
    if isinstance(data, dict) and "data" in data:
        data = data["data"]
    if not isinstance(data, list):
        return []
    # The loader skips non-object items, so look past them here as well
    # instead of inspecting only data[0].
    first_record = next((item for item in data if isinstance(item, dict)), None)
    if first_record is None:
        return []
    return list(first_record.keys())
# ---------------------------------------------------------------------------
# Tokenizer training
# ---------------------------------------------------------------------------
def train_custom_tokenizer(texts: List[str], vocab_size: int, output_dir: str) -> PreTrainedTokenizerFast:
    """Train a byte-level BPE tokenizer on ``texts`` and save it.

    Args:
        texts: Training samples the tokenizer is fitted on.
        vocab_size: Target vocabulary size (the trainer may produce fewer
            tokens on small corpora).
        output_dir: Directory that receives ``tokenizer.json`` and the files
            written by ``save_pretrained``; created if missing.

    Returns:
        The trained tokenizer wrapped as a ``PreTrainedTokenizerFast`` with
        bos/eos/pad/unk tokens configured.
    """
    os.makedirs(output_dir, exist_ok=True)

    tokenizer_raw = ByteLevelBPETokenizer(add_prefix_space=True)
    tokenizer_raw.train_from_iterator(
        texts,
        # UI sliders may deliver the value as a float.
        vocab_size=int(vocab_size),
        min_frequency=2,
        # Special tokens receive ids 0, 1, 2, 3 in list order. This order
        # gives <s>=0, </s>=1, <pad>=2, which is what the model config in
        # this file hard-codes (bos=0, eos=1, pad=2).
        special_tokens=["<s>", "</s>", "<pad>", "<unk>"],
    )

    tokenizer_path = os.path.join(output_dir, "tokenizer.json")
    tokenizer_raw.save(tokenizer_path)

    tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=tokenizer_path,
        bos_token="<s>",
        eos_token="</s>",
        pad_token="<pad>",
        unk_token="<unk>",
    )
    tokenizer.save_pretrained(output_dir)
    return tokenizer
# ---------------------------------------------------------------------------
# Model creation
# ---------------------------------------------------------------------------
def create_model(
    vocab_size: int,
    hidden_size: int,
    num_layers: int,
    num_heads: int,
    max_length: int,
    bos_token_id: int = 0,
    eos_token_id: int = 1,
    pad_token_id: int = 2,
) -> GPT2LMHeadModel:
    """Create a randomly initialised GPT-2 language model from a config.

    Args:
        vocab_size: Size of the token embedding table.
        hidden_size: Embedding width (``n_embd``); the feed-forward width is
            four times this value.
        num_layers: Number of transformer blocks.
        num_heads: Number of attention heads; must divide ``hidden_size``.
        max_length: Maximum sequence length (``n_positions``).
        bos_token_id: Id of the beginning-of-sequence token.
        eos_token_id: Id of the end-of-sequence token.
        pad_token_id: Id of the padding token.

    The three token ids default to the values this function previously
    hard-coded. They must agree with the ids the tokenizer actually assigns
    to its special tokens; pass the tokenizer's ids when they differ.

    Returns:
        A new, untrained ``GPT2LMHeadModel``.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_heads``.
    """
    if hidden_size % num_heads != 0:
        # Fail early with a message phrased in the caller's own parameters.
        raise ValueError(
            f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
        )

    config = GPT2Config(
        vocab_size=vocab_size,
        n_positions=max_length,
        n_embd=hidden_size,
        n_layer=num_layers,
        n_head=num_heads,
        n_inner=hidden_size * 4,
        bos_token_id=bos_token_id,
        eos_token_id=eos_token_id,
        pad_token_id=pad_token_id,
    )
    return GPT2LMHeadModel(config)
# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
def tokenize_dataset(dataset: Dataset, tokenizer: PreTrainedTokenizerFast, max_length: int):
    """Tokenize the ``"text"`` column of ``dataset`` in batches.

    Every sample is truncated and padded to exactly ``max_length`` tokens.
    The ``"text"`` column is removed from the result.

    Args:
        dataset: Dataset with a ``"text"`` column.
        tokenizer: Tokenizer applied to each batch of texts.
        max_length: Fixed sequence length of the encoded samples.

    Returns:
        The value returned by ``dataset.map``.
    """
    encode_options = {
        "truncation": True,
        "max_length": max_length,
        "padding": "max_length",
    }

    def _encode_batch(batch):
        return tokenizer(batch["text"], **encode_options)

    return dataset.map(_encode_batch, batched=True, remove_columns=["text"])
class TrainingStatus:
    """Thread-safe(ish) status holder updated by the Trainer callback.

    Collects log lines and a few progress fields that the UI reads.

    Args:
        max_lines: How many of the most recent log lines ``get_text``
            returns. Defaults to 200.

    Raises:
        ValueError: If ``max_lines`` is smaller than 1.
    """

    def __init__(self, max_lines: int = 200):
        if max_lines < 1:
            raise ValueError("max_lines must be at least 1")
        self.max_lines = max_lines
        self.logs: List[str] = []
        self.step = 0
        self.total_steps = 0
        self.loss: Optional[float] = None
        self.done = False
        self.error: Optional[str] = None

    def append(self, msg: str):
        """Add one line to the log."""
        self.logs.append(msg)

    def get_text(self) -> str:
        """Return the most recent ``max_lines`` log lines joined by newlines."""
        return "\n".join(self.logs[-self.max_lines:])


# Module-level status shared by the training entry point and the callback.
status = TrainingStatus()
class StatusCallback(TrainerCallback):
    """HuggingFace Trainer callback that feeds our UI.

    It subclasses ``TrainerCallback`` so that every event the Trainer
    dispatches (not only the two overridden here) has a no-op handler.

    Args:
        total_steps: Estimated number of optimisation steps, used in log
            lines when the trainer state does not report ``max_steps``.
    """

    # NOTE(review): the non-ASCII glyphs in the messages below look like
    # mis-decoded text (probably an em dash and a check-mark emoji); they are
    # reproduced as found -- restore them from the original source.

    def __init__(self, total_steps: int):
        self.total_steps = total_steps

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Record the current step and loss and append a progress line."""
        if logs is None:
            return
        step = state.global_step
        status.step = step
        # Prefer the trainer's own step count; fall back to our estimate.
        total = getattr(state, "max_steps", 0) or self.total_steps
        status.total_steps = total

        loss = logs.get("loss")
        if isinstance(loss, (int, float)):
            status.loss = loss
            loss_text = f"{loss:.4f}"
        else:
            # Some log records carry no "loss"; applying a float format spec
            # to the "n/a" placeholder would raise ValueError.
            loss_text = "n/a"
        status.append(f"Step {step}/{total} β loss={loss_text}")

    def on_train_end(self, args, state, control, **kwargs):
        """Mark the run as finished."""
        status.append("β Training complete!")
        status.done = True
# ---------------------------------------------------------------------------
# Main training orchestrator
# ---------------------------------------------------------------------------
def run_training(
    file_obj,
    file_type: str,
    question_col: str,
    answer_col: str,
    prompt_template: str,
    vocab_size: int,
    hidden_size: int,
    num_layers: int,
    num_heads: int,
    max_length: int,
    num_epochs: int,
    batch_size: int,
    learning_rate: float,
    output_dir: str,
    progress=gr.Progress(),
):
    """Main training entry-point used by Gradio.

    Generator that loads the uploaded data, trains a tokenizer, builds and
    trains a GPT-2 model, saves everything under ``output_dir`` and zips it.

    Args:
        file_obj: The upload; either a path string or an object whose
            ``name`` attribute is the path. ``None`` means nothing uploaded.
        file_type: Dataset type chosen in the UI (not used here; the loader
            is picked from the file extension).
        question_col: Question column/key for CSV and JSON input.
        answer_col: Answer column/key for CSV and JSON input.
        prompt_template: ``str.format`` template for CSV and JSON input.
        vocab_size: Requested tokenizer vocabulary size.
        hidden_size: Model embedding width.
        num_layers: Number of transformer blocks.
        num_heads: Number of attention heads.
        max_length: Maximum sequence length.
        num_epochs: Number of training epochs.
        batch_size: Per-device training batch size.
        learning_rate: Optimiser learning rate.
        output_dir: Directory for checkpoints, model, tokenizer and README;
            blank falls back to ``DEFAULT_OUTPUT_DIR``.
        progress: Gradio progress tracker (accepted but not used here).

    Yields:
        ``(log_text, zip_path)`` tuples; ``zip_path`` is ``None`` until the
        final yield of a successful run.
    """
    # NOTE(review): the non-ASCII glyphs in the status strings below look like
    # mis-decoded emoji/punctuation; they are reproduced as found -- restore
    # them from the original source.
    global status
    status = TrainingStatus()

    # --- 1. Load data ---
    status.append("π Loading dataβ¦")
    yield status.get_text(), None

    if file_obj is None:
        status.error = "No file uploaded."
        yield f"β Error: {status.error}", None
        return

    # gr.File(type="filepath") hands over a plain path string; other modes
    # pass an object exposing the path as ``.name``.
    filepath = file_obj if isinstance(file_obj, str) else file_obj.name
    ext = Path(filepath).suffix.lower()
    if ext not in (".txt", ".csv", ".json"):
        status.error = f"Unsupported file extension: {ext}"
        yield f"β Error: {status.error}", None
        return

    try:
        if ext == ".txt":
            texts = load_text_from_txt(filepath)
        elif ext == ".csv":
            texts = load_qa_from_csv(filepath, question_col, answer_col, prompt_template)
        else:
            texts = load_qa_from_json(filepath, question_col, answer_col, prompt_template)
    except (OSError, ValueError, KeyError, IndexError) as exc:
        # ValueError covers missing columns, malformed JSON/CSV and decoding
        # errors; KeyError/IndexError come from a template with unknown
        # placeholders. Report them in the log instead of crashing the UI.
        status.error = f"Could not load {Path(filepath).name}: {exc}"
        yield f"β Error: {status.error}", None
        return

    if not texts:
        status.error = "No valid text samples found in file."
        yield f"β Error: {status.error}", None
        return
    status.append(f"β Loaded {len(texts)} text samples.")
    yield status.get_text(), None

    # UI sliders may deliver whole numbers as floats.
    vocab_size = int(vocab_size)
    num_epochs = int(num_epochs)
    batch_size = max(1, int(batch_size))
    output_dir = (output_dir or "").strip() or DEFAULT_OUTPUT_DIR

    # --- 2. Train tokenizer ---
    status.append("π€ Training tokenizerβ¦")
    yield status.get_text(), None
    tokenizer_output = os.path.join(output_dir, "tokenizer")
    os.makedirs(tokenizer_output, exist_ok=True)
    tokenizer = train_custom_tokenizer(texts, vocab_size, tokenizer_output)
    status.append(f"β Tokenizer saved to {tokenizer_output}")
    yield status.get_text(), None

    # --- 3. Create model ---
    status.append("ποΈ Creating modelβ¦")
    yield status.get_text(), None
    # len() also counts any tokens added outside the trained vocabulary.
    actual_vocab_size = len(tokenizer)
    model = create_model(
        vocab_size=actual_vocab_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        num_heads=num_heads,
        max_length=max_length,
    )
    # Make the model agree with the ids the tokenizer really assigned to its
    # special tokens, instead of trusting hard-coded defaults.
    generation_config = getattr(model, "generation_config", None)
    for id_name in ("bos_token_id", "eos_token_id", "pad_token_id"):
        token_id = getattr(tokenizer, id_name, None)
        if token_id is None:
            continue
        setattr(model.config, id_name, token_id)
        if generation_config is not None:
            setattr(generation_config, id_name, token_id)
    status.append(f"β Model created: {num_layers} layers, {hidden_size} hidden, {num_heads} heads")
    yield status.get_text(), None

    # --- 4. Prepare dataset ---
    status.append("π Preparing datasetβ¦")
    yield status.get_text(), None
    dataset = Dataset.from_dict({"text": texts})
    tokenized = tokenize_dataset(dataset, tokenizer, max_length)
    status.append(f"β Dataset tokenized: {len(tokenized)} samples")
    yield status.get_text(), None

    # --- 5. Train ---
    status.append(f"π Starting training ({num_epochs} epochs, lr={learning_rate})β¦")
    yield status.get_text(), None
    os.makedirs(output_dir, exist_ok=True)

    # Round up: a final partial batch still counts as a step.
    steps_per_epoch = max(1, (len(tokenized) + batch_size - 1) // batch_size)
    total_steps = steps_per_epoch * num_epochs

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        learning_rate=learning_rate,
        weight_decay=0.01,
        logging_strategy="steps",
        logging_steps=max(1, total_steps // 20),
        save_strategy="epoch",
        save_total_limit=2,
        warmup_steps=max(1, total_steps // 10),
        fp16=False,
        dataloader_num_workers=0,
        disable_tqdm=True,
        logging_first_step=True,
    )
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Build Trainer kwargs: detect whether 'processing_class' or 'tokenizer'
    # is the keyword this transformers version accepts.
    trainer_params = inspect.signature(Trainer.__init__).parameters
    trainer_kwargs = {
        "model": model,
        "args": training_args,
        "train_dataset": tokenized,
        "data_collator": data_collator,
        "callbacks": [StatusCallback(total_steps)],
    }
    if "processing_class" in trainer_params:
        trainer_kwargs["processing_class"] = tokenizer
    elif "tokenizer" in trainer_params:
        trainer_kwargs["tokenizer"] = tokenizer

    trainer = Trainer(**trainer_kwargs)
    # NOTE: train() blocks this generator, so the lines the callback appends
    # only reach the UI with the next yield below.
    trainer.train()

    # --- 6. Save everything ---
    status.append("πΎ Saving model & tokenizerβ¦")
    yield status.get_text(), None
    model.save_pretrained(os.path.join(output_dir, "model"))
    tokenizer.save_pretrained(os.path.join(output_dir, "tokenizer"))

    # Also save a combined README. It reports the vocabulary size the model
    # was actually built with, which can be below the requested size.
    readme_lines = [
        "# Trained Transformer Model",
        "## Architecture",
        "- **Type:** GPT-2 causal language model",
        f"- **Hidden size:** {hidden_size}",
        f"- **Layers:** {num_layers}",
        f"- **Attention heads:** {num_heads}",
        f"- **Max sequence length:** {max_length}",
        f"- **Vocab size:** {actual_vocab_size}",
        "## Training",
        f"- **Epochs:** {num_epochs}",
        f"- **Batch size:** {batch_size}",
        f"- **Learning rate:** {learning_rate}",
        f"- **Samples:** {len(texts)}",
        "## Files",
        "- `model/` β model weights + config",
        "- `tokenizer/` β tokenizer vocab + config",
        "- `tokenizer/tokenizer.json` β raw tokenizer file",
        "## Usage",
        "```python",
        "from transformers import GPT2LMHeadModel, PreTrainedTokenizerFast",
        f'model = GPT2LMHeadModel.from_pretrained("{output_dir}/model")',
        f'tokenizer = PreTrainedTokenizerFast.from_pretrained("{output_dir}/tokenizer")',
        "```",
    ]
    readme_path = os.path.join(output_dir, "README.md")
    with open(readme_path, "w", encoding="utf-8") as f:
        f.write("\n".join(readme_lines) + "\n")

    # Package as a zip for easy download. Strip trailing separators so the
    # archive is created next to the directory, not inside it.
    archive_base = output_dir.rstrip("/\\") or output_dir
    zip_path = shutil.make_archive(archive_base, "zip", output_dir)
    status.append(f"β All done! Model saved to `{output_dir}`")
    status.append(f"π¦ Download zip: `{zip_path}`")
    status.done = True
    yield status.get_text(), zip_path
# ---------------------------------------------------------------------------
# Gradio UI helpers
# ---------------------------------------------------------------------------
def update_ui_visibility(file_type: str):
    """Show/hide the Q&A inputs and update the upload hint for ``file_type``.

    Args:
        file_type: Selected dataset type, e.g. ``"CSV Q&A pairs (.csv)"``.

    Returns:
        Four ``gr.update`` objects, in order: question column, answer column,
        prompt template (visibility each) and the file upload (label).
    """
    is_plain_text = file_type == "Plain text (.txt)"
    if is_plain_text:
        hint = "Upload a .txt file with raw text"
    else:
        # The extension sits in parentheses at the end of the choice label.
        _, _, tail = file_type.partition("(")
        extension = tail.replace(")", "")
        hint = f"Upload a {extension} file" if extension else "Upload file"

    show_qa_fields = not is_plain_text
    return [
        gr.update(visible=show_qa_fields),  # question_col
        gr.update(visible=show_qa_fields),  # answer_col
        gr.update(visible=show_qa_fields),  # prompt_template
        # The file component has no ``placeholder`` property; its label is
        # the user-visible hint it does support.
        gr.update(label=hint),
    ]
def auto_detect_cols(file_obj, file_type: str):
    """Suggest question/answer column names for an uploaded CSV or JSON file.

    Args:
        file_obj: The upload; either a path string or an object whose
            ``name`` attribute is the path. ``None`` means nothing uploaded.
        file_type: Selected dataset type from the UI.

    Returns:
        A ``(question_col, answer_col)`` tuple. Falls back to
        ``("question", "answer")`` when nothing is uploaded, the type is
        plain text, the extension is unsupported, or the file cannot be read.
    """
    default_pair = ("question", "answer")
    if file_obj is None or file_type == "Plain text (.txt)":
        return default_pair

    # gr.File(type="filepath") hands over a plain path string; other modes
    # pass an object exposing the path as ``.name``.
    filepath = file_obj if isinstance(file_obj, str) else file_obj.name
    ext = Path(filepath).suffix.lower()
    try:
        if ext == ".csv":
            cols = detect_columns_csv(filepath)
        elif ext == ".json":
            cols = detect_columns_json(filepath)
        else:
            return default_pair
    except Exception:
        # Detection is a best-effort convenience: any read or parse failure
        # simply yields the default names.
        return default_pair

    def _first_match(exact, keywords, fallback):
        # First column whose lower-cased name equals `exact` or contains one
        # of `keywords`; `fallback` when none matches.
        for col in cols:
            lowered = str(col).lower()
            if lowered == exact or any(word in lowered for word in keywords):
                return col
        return fallback

    # Simple heuristics, with positional fallbacks.
    q_fallback = cols[0] if cols else "question"
    if len(cols) > 1:
        a_fallback = cols[1]
    else:
        a_fallback = cols[0] if cols else "answer"

    q_col = _first_match("q", ("question", "prompt"), q_fallback)
    a_col = _first_match("a", ("answer", "response", "output"), a_fallback)
    return q_col, a_col
# ---------------------------------------------------------------------------
# Gradio App
# ---------------------------------------------------------------------------
# NOTE(review): the non-ASCII glyphs in the labels below look like mis-decoded
# emoji; they are reproduced as found -- restore them from the original source.
with gr.Blocks(title="π§ Tiny Transformer Trainer") as demo:
    gr.Markdown("""
    # π§ Tiny Transformer Trainer
    Upload your text data and train a small GPT model from scratch.
    Supports `.txt` (plain text), `.csv` and `.json` (Q&A pairs).
    """)

    with gr.Row():
        # Left column: data upload, architecture and training settings.
        with gr.Column(scale=1):
            gr.Markdown("### π€ Data Upload")
            file_type = gr.Dropdown(
                choices=["Plain text (.txt)", "CSV Q&A pairs (.csv)", "JSON Q&A pairs (.json)"],
                value="Plain text (.txt)",
                label="Dataset type",
            )
            file_input = gr.File(label="Upload file", type="filepath")
            # The Q&A controls start hidden because plain text is the default type.
            question_col = gr.Textbox(value="question", label="Question/prompt column name", visible=False)
            answer_col = gr.Textbox(value="answer", label="Answer/response column name", visible=False)
            prompt_template = gr.Textbox(
                value=PROMPT_TEMPLATE_DEFAULT,
                label="Prompt template (use {question} and {answer})",
                visible=False,
            )
            auto_detect_btn = gr.Button("π Auto-detect columns", visible=False)

            gr.Markdown("---")
            gr.Markdown("### ποΈ Model Architecture")
            vocab_size = gr.Slider(1000, 32768, value=10000, step=1000, label="Vocabulary size")
            hidden_size = gr.Dropdown(choices=HIDDEN_SIZES, value=256, label="Hidden size (embedding dim)")
            num_layers = gr.Dropdown(choices=LAYER_COUNTS, value=4, label="Number of layers")
            num_heads = gr.Dropdown(choices=HEAD_COUNTS, value=4, label="Attention heads")
            max_length = gr.Dropdown(choices=MAX_SEQ_LENS, value=256, label="Max sequence length")

            gr.Markdown("---")
            gr.Markdown("### βοΈ Training Settings")
            num_epochs = gr.Slider(1, 20, value=3, step=1, label="Epochs")
            batch_size = gr.Slider(1, 32, value=8, step=1, label="Batch size")
            learning_rate = gr.Number(value=5e-4, label="Learning rate")
            output_dir = gr.Textbox(value=DEFAULT_OUTPUT_DIR, label="Output directory")
            train_btn = gr.Button("π Start Training", variant="primary")

        # Right column: live log and the downloadable result.
        with gr.Column(scale=1):
            gr.Markdown("### π Training Log")
            log_box = gr.Textbox(label="", lines=25, interactive=False, show_label=False)
            zip_download = gr.File(label="π¦ Download trained model (zip)", visible=True)

    # -------------------------------------------------------------------
    # Event wiring
    # -------------------------------------------------------------------
    def _toggle_auto_detect(selected_type):
        # The auto-detect button only makes sense for CSV/JSON input.
        return gr.update(visible=(selected_type != "Plain text (.txt)"))

    file_type.change(
        update_ui_visibility,
        inputs=[file_type],
        outputs=[question_col, answer_col, prompt_template, file_input],
    )
    # Also toggle auto-detect button visibility
    file_type.change(
        _toggle_auto_detect,
        inputs=[file_type],
        outputs=[auto_detect_btn],
    )
    auto_detect_btn.click(
        auto_detect_cols,
        inputs=[file_input, file_type],
        outputs=[question_col, answer_col],
    )
    train_btn.click(
        run_training,
        inputs=[
            file_input, file_type, question_col, answer_col, prompt_template,
            vocab_size, hidden_size, num_layers, num_heads, max_length,
            num_epochs, batch_size, learning_rate, output_dir,
        ],
        outputs=[log_box, zip_download],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, theme=gr.themes.Soft())