Spaces:
Sleeping
Sleeping
File size: 1,950 Bytes
cf19834 8a650f0 cf19834 3f48220 6f9ff60 3f48220 cf19834 8a650f0 3f48220 cf19834 3f48220 8a650f0 cf19834 80771ae 8a650f0 cf19834 8a650f0 cf19834 8a650f0 80771ae 8a650f0 3f48220 8a650f0 3f48220 8a650f0 cf19834 3f48220 8a650f0 cf19834 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import json
import os
import zipfile
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
# Training configuration.
MODEL_NAME = "distilgpt2"        # small GPT-2 variant; quick to fine-tune on CPU
OUTPUT_DIR = "trained_model"     # where Trainer checkpoints and the final model land
MAX_LENGTH = 128                 # fixed token length every example is padded/truncated to
# Load dataset from local JSONL
# Expects one JSON object per line with "prompt" and "completion" fields
# (see tokenize_function below) — TODO confirm against sample_dataset.jsonl.
dataset = load_dataset("json", data_files={"train": "sample_dataset.jsonl"})
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# GPT-2-family tokenizers ship without a pad token; reuse EOS so that
# padding="max_length" below works.
tokenizer.pad_token = tokenizer.eos_token
def tokenize_function(examples):
    """Tokenize batched prompt/completion pairs into fixed-length causal-LM examples.

    Args:
        examples: a batched `datasets` mapping with parallel lists under
            "prompt" and "completion". A completion may be a plain value or a
            dict with question/options/answer/explanation keys.

    Returns:
        dict with "input_ids", "attention_mask", and "labels", each a list of
        length-MAX_LENGTH lists. Labels mirror input_ids except that padding
        positions are set to -100 so the loss ignores them.
    """
    texts = []
    for p, c in zip(examples["prompt"], examples["completion"]):
        # Convert dict completion to string
        if isinstance(c, dict):
            c_str = f"Q: {c.get('question','')} Options: {c.get('options',[])} Answer: {c.get('answer','')} Explanation: {c.get('explanation','')}"
        else:
            c_str = str(c)
        texts.append(p + " " + c_str)
    tokens = tokenizer(texts, truncation=True, padding="max_length", max_length=MAX_LENGTH)
    # FIX: previously labels were a raw copy of input_ids, so the model was
    # trained to predict padding. Because pad_token == eos_token here, padding
    # cannot be identified by token id — use attention_mask (0 = padding) and
    # replace those label positions with -100, which CrossEntropyLoss ignores.
    tokens["labels"] = [
        [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
        for ids, attn in zip(tokens["input_ids"], tokens["attention_mask"])
    ]
    return tokens
# Tokenize the whole dataset in batches; drop the raw text columns so the
# Trainer only sees model inputs (input_ids / attention_mask / labels).
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
# Hyperparameters for a short demonstration run: one epoch, tiny batch,
# per-step logging, a single checkpoint saved at the end of the epoch.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    overwrite_output_dir=True,       # clobber any previous run in OUTPUT_DIR
    per_device_train_batch_size=2,
    num_train_epochs=1,
    logging_steps=1,                 # log loss every optimizer step
    save_strategy="epoch",
    learning_rate=2e-5,
)
# Run fine-tuning, then persist both the model weights and the tokenizer to
# OUTPUT_DIR so the directory is self-contained for later from_pretrained().
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
)
trainer.train()
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
# Bundle the saved model directory into one compressed archive for download.
zip_filename = "trained_model.zip"
with zipfile.ZipFile(zip_filename, "w", zipfile.ZIP_DEFLATED) as archive:
    # Walk every file under OUTPUT_DIR; store each entry under a path
    # relative to OUTPUT_DIR so unzipping recreates just the model tree.
    for dirpath, _subdirs, filenames in os.walk(OUTPUT_DIR):
        for name in filenames:
            full_path = os.path.join(dirpath, name)
            arcname = os.path.relpath(full_path, OUTPUT_DIR)
            archive.write(full_path, arcname)