import urllib.request

import torch
import gradio as gr
from transformers import (
    GPT2Tokenizer,
    GPT2LMHeadModel,
    TextDataset,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    pipeline,
)

# Load the pre-trained GPT-2 tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')


# Load and tokenize the dataset.
# Note: TextDataset is deprecated in recent versions of transformers (the
# `datasets` library is the recommended replacement), but it still works for
# a simple plain-text training file like this one.
def load_dataset(file_path, tokenizer, block_size=128):
    try:
        dataset = TextDataset(
            tokenizer=tokenizer,
            file_path=file_path,
            block_size=block_size,
        )
        return dataset
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return None


# URL of the custom dataset. TextDataset only accepts a local file path, so
# download the remote file first (this assumes the URL is reachable at runtime).
dataset_url = "https://huggingface.co/spaces/soalwin/meow/resolve/main/gpt2trainmodel.txt"
file_path = "gpt2trainmodel.txt"
urllib.request.urlretrieve(dataset_url, file_path)

# Load and tokenize the dataset
dataset = load_dataset(file_path, tokenizer)
if dataset is None:
    raise ValueError("Failed to load dataset. Please check the file path and format.")

# Collator for causal language modeling (mlm=False disables masked-LM labels)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

# Load the pre-trained GPT-2 model
model = GPT2LMHeadModel.from_pretrained('gpt2')

# Set up training arguments
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    overwrite_output_dir=True,       # overwrite the content of the output directory
    num_train_epochs=3,              # number of training epochs
    per_device_train_batch_size=4,   # batch size for training
    save_steps=10_000,               # save a checkpoint every 10,000 steps
    save_total_limit=2,              # keep only the last 2 checkpoints
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

# Start training
trainer.train()

# Save the fine-tuned model and tokenizer
model.save_pretrained('./fine-tuned-gpt2')
tokenizer.save_pretrained('./fine-tuned-gpt2')

# Reload the fine-tuned model once and build the text-generation pipeline
# (creating the pipeline outside generate_text avoids reloading the model on every call)
fine_tuned_model = GPT2LMHeadModel.from_pretrained('./fine-tuned-gpt2')
fine_tuned_tokenizer = GPT2Tokenizer.from_pretrained('./fine-tuned-gpt2')
generator = pipeline('text-generation', model=fine_tuned_model, tokenizer=fine_tuned_tokenizer)


# Generate text from a prompt using the fine-tuned model
def generate_text(prompt):
    output = generator(prompt, max_length=100, num_return_sequences=1)
    return output[0]['generated_text']


# Create and launch a Gradio interface around the generator
iface = gr.Interface(fn=generate_text, inputs="text", outputs="text")
iface.launch()