Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import zipfile | |
| from datasets import load_dataset | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments | |
# --- Configuration ---------------------------------------------------------
MODEL_NAME = "distilgpt2"     # Hub id of the base model that gets fine-tuned
OUTPUT_DIR = "trained_model"  # Trainer output dir; final model is saved here too
MAX_LENGTH = 128              # examples are truncated/padded to this many tokens

# --- Data ------------------------------------------------------------------
# Read the local JSONL file with the generic "json" loader and expose it as
# the "train" split (accessed later as dataset["train"]).
dataset = load_dataset(
    "json",
    data_files=dict(train="sample_dataset.jsonl"),
)

# --- Tokenizer -------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse the end-of-sequence token as the padding token so that
# padding="max_length" during tokenization has a token to pad with
# (presumably the base tokenizer defines no pad token of its own - TODO confirm).
tokenizer.pad_token = tokenizer.eos_token
| def tokenize_function(examples): | |
| texts = [] | |
| for p, c in zip(examples["prompt"], examples["completion"]): | |
| # Convert dict completion to string | |
| if isinstance(c, dict): | |
| c_str = f"Q: {c.get('question','')} Options: {c.get('options',[])} Answer: {c.get('answer','')} Explanation: {c.get('explanation','')}" | |
| else: | |
| c_str = str(c) | |
| texts.append(p + " " + c_str) | |
| tokens = tokenizer(texts, truncation=True, padding="max_length", max_length=MAX_LENGTH) | |
| tokens["labels"] = tokens["input_ids"].copy() | |
| return tokens | |
# Tokenize every split in batches. All raw JSONL columns are dropped, so only
# the fields returned by tokenize_function reach the Trainer.
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset["train"].column_names,
)

# Load the pretrained base model that will be fine-tuned.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# Small run: one epoch at batch size 2 per device, loss logged every step,
# and a checkpoint written into OUTPUT_DIR at the end of each epoch.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=2,
    learning_rate=2e-5,
    logging_steps=1,
    save_strategy="epoch",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
)
trainer.train()

# Save the fine-tuned model and the tokenizer into the same directory so
# both can be reloaded from OUTPUT_DIR.
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
# Zip the trained model for download.
# Only the final artifacts written into OUTPUT_DIR by save_model() /
# save_pretrained() are archived. Trainer also writes intermediate
# "checkpoint-<step>" sub-folders there (save_strategy="epoch"). Those
# duplicate the weights and add optimizer/scheduler state, which bloats the
# download without being needed to load the model, so they are pruned.
zip_filename = "trained_model.zip"
with zipfile.ZipFile(zip_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
    for root, dirs, files in os.walk(OUTPUT_DIR):
        if root == OUTPUT_DIR:
            # Editing dirs in place stops os.walk from descending into them.
            dirs[:] = [d for d in dirs if not d.startswith("checkpoint-")]
        # Sorted traversal gives a reproducible archive layout.
        dirs.sort()
        for file in sorted(files):
            filepath = os.path.join(root, file)
            # Store paths relative to OUTPUT_DIR so the archive unpacks to the
            # model files directly, without the directory prefix.
            zipf.write(filepath, os.path.relpath(filepath, OUTPUT_DIR))