# meow / app.py — Hugging Face Space entry point
# (repo: soalwin/meow, commit 81d071c "Update app.py")
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments, pipeline
import gradio as gr
# Load the tokenizer: byte-pair-encoding tokenizer for the base GPT-2
# checkpoint. Reused below for dataset tokenization, the data collator,
# and is saved alongside the fine-tuned model.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# Function to load and tokenize dataset
def load_dataset(file_path, tokenizer, block_size=128):
    """Build a ``TextDataset`` from a local path or an HTTP(S) URL.

    ``TextDataset`` reads the corpus with local file I/O, so a remote URL
    (as used by this Space) is first downloaded to a temporary file —
    passing the URL straight through would always fail.

    Args:
        file_path: Local path or http(s) URL of a plain-text corpus.
        tokenizer: Tokenizer used to encode the text.
        block_size: Length, in tokens, of each training example.

    Returns:
        The tokenized ``TextDataset``, or ``None`` on failure (deliberate
        best-effort: the caller checks for ``None`` and raises).
    """
    try:
        # TextDataset calls open(file_path); fetch remote corpora first.
        if file_path.startswith(("http://", "https://")):
            import tempfile
            import urllib.request
            with urllib.request.urlopen(file_path) as resp, \
                    tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as tmp:
                tmp.write(resp.read())
                file_path = tmp.name
        dataset = TextDataset(
            tokenizer=tokenizer,
            file_path=file_path,
            block_size=block_size,
        )
        return dataset
    except Exception as e:
        # Best-effort by design: report and signal failure with None.
        print(f"Error loading dataset: {e}")
        return None
# Remote plain-text corpus used to fine-tune GPT-2.
file_path = "https://huggingface.co/spaces/soalwin/meow/resolve/main/gpt2trainmodel.txt"

# Tokenize the corpus; load_dataset returns None when it cannot be read.
dataset = load_dataset(file_path, tokenizer)
if dataset is None:
    raise ValueError("Failed to load dataset. Please check the file path and format.")

# Causal-LM batching: mlm=False turns off masked-language-model masking,
# so labels are the inputs shifted for next-token prediction.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# Fine-tune the stock GPT-2 checkpoint on the custom corpus.
model = GPT2LMHeadModel.from_pretrained('gpt2')

# Hyper-parameters for the fine-tuning run.
training_args = TrainingArguments(
    output_dir='./results',              # where checkpoints are written
    overwrite_output_dir=True,           # clobber any previous run
    num_train_epochs=3,                  # full passes over the corpus
    per_device_train_batch_size=4,       # examples per device per step
    save_steps=10_000,                   # checkpoint interval in steps
    save_total_limit=2,                  # keep only the two newest checkpoints
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

# NOTE(review): training runs at import time, so the Space retrains on
# every launch before the UI appears — consider guarding or pre-baking.
trainer.train()

# Persist weights and tokenizer for generate_text() to reload from disk.
model.save_pretrained('./fine-tuned-gpt2')
tokenizer.save_pretrained('./fine-tuned-gpt2')
# Function to generate text using the fine-tuned model
def generate_text(prompt):
    """Generate a continuation of *prompt* with the fine-tuned GPT-2 model.

    The original implementation reloaded the model, tokenizer, and
    pipeline from disk on every call; the pipeline is now built once and
    cached on the function object, so subsequent UI calls are fast.

    Args:
        prompt: Seed text to continue.

    Returns:
        The generated text (prompt included), up to 100 tokens long.
    """
    if not hasattr(generate_text, "_generator"):
        # Lazy one-time load of the fine-tuned artifacts saved above.
        model = GPT2LMHeadModel.from_pretrained('./fine-tuned-gpt2')
        tokenizer = GPT2Tokenizer.from_pretrained('./fine-tuned-gpt2')
        generate_text._generator = pipeline(
            'text-generation', model=model, tokenizer=tokenizer
        )
    output = generate_text._generator(prompt, max_length=100, num_return_sequences=1)
    return output[0]['generated_text']
# Minimal text-in / text-out Gradio UI around the fine-tuned model.
iface = gr.Interface(
    fn=generate_text,
    inputs="text",
    outputs="text",
)
# Start the web server for the Space.
iface.launch()