K00B404 commited on
Commit
4f117ff
·
verified ·
1 Parent(s): e79fa01

Create app3.py

Browse files
Files changed (1) hide show
  1. app3.py +68 -0
app3.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer, TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments
3
+ from datasets import load_dataset
4
+ import numpy as np
5
+ import torch
6
+
7
+ # Load GPT2 Model and Tokenizer
8
+ model_name = "gpt2"
9
+ tokenizer = GPT2Tokenizer.from_pretrained(model_name)
10
+ model = GPT2LMHeadModel.from_pretrained(model_name)
11
+
12
+ # Define PPO Training Function (simplified)
13
+ def fine_tune_gpt2_with_ppo(dataset_name, epochs, learning_rate):
14
+ # Load the dataset
15
+ dataset = load_dataset(dataset_name)
16
+
17
+ # Prepare dataset for GPT-2 training
18
+ def encode(examples):
19
+ return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=128)
20
+
21
+ tokenized_dataset = dataset.map(encode, batched=True)
22
+ train_dataset = tokenized_dataset["train"]
23
+
24
+ # Prepare data collator and training arguments
25
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
26
+ training_args = TrainingArguments(
27
+ output_dir="./results",
28
+ overwrite_output_dir=True,
29
+ num_train_epochs=epochs,
30
+ per_device_train_batch_size=4,
31
+ save_steps=10_000,
32
+ save_total_limit=2,
33
+ learning_rate=learning_rate
34
+ )
35
+
36
+ # Trainer
37
+ trainer = Trainer(
38
+ model=model,
39
+ args=training_args,
40
+ data_collator=data_collator,
41
+ train_dataset=train_dataset
42
+ )
43
+
44
+ # Train model
45
+ trainer.train()
46
+
47
+ return "Training Completed!"
48
+
49
+ # Gradio Interface
50
+ def train_interface(dataset, epochs, learning_rate):
51
+ result = fine_tune_gpt2_with_ppo(dataset, int(epochs), float(learning_rate))
52
+ return result
53
+
54
+ # Gradio App
55
+ gradio_interface = gr.Interface(
56
+ fn=train_interface,
57
+ inputs=[
58
+ gr.inputs.Textbox(label="Dataset (e.g. 'wikitext')"),
59
+ gr.inputs.Slider(1, 10, step=1, label="Epochs"),
60
+ gr.inputs.Textbox(label="Learning Rate")
61
+ ],
62
+ outputs="text",
63
+ title="GPT-2 RL Training App",
64
+ description="Fine-tune GPT-2 using PPO via a Gradio interface."
65
+ )
66
+
67
+ # Launch the app
68
+ gradio_interface.launch()