Dorn4449 commited on
Commit
a30b42a
·
verified ·
1 Parent(s): 7590662

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -0
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
3
+ from datasets import load_dataset
4
+
5
+ # Define datasets and their IDs
6
+ datasets_info = {
7
+ "SQuAD": "squad",
8
+ "SQuAD 2.0": "squad_v2",
9
+ "Natural Questions": "nq",
10
+ "TriviaQA": "triviaqa",
11
+ "QuAC": "quac",
12
+ "FAQ Dataset": "faq",
13
+ "BoolQ": "boolq",
14
+ "Open Book QA": "obqa"
15
+ }
16
+
17
+ # Load model and tokenizer directly
18
+ tokenizer = AutoTokenizer.from_pretrained("nvidia/Llama-3.1-Nemotron-70B-Instruct-HF")
19
+ model = AutoModelForCausalLM.from_pretrained("nvidia/Llama-3.1-Nemotron-70B-Instruct-HF")
20
+
21
+ def train_model(dataset_name):
22
+ # Load the dataset
23
+ dataset = load_dataset(datasets_info[dataset_name])
24
+
25
+ # Tokenization
26
+ def preprocess_function(examples):
27
+ return tokenizer(examples['question'], examples['context'], truncation=True)
28
+
29
+ tokenized_dataset = dataset.map(preprocess_function, batched=True)
30
+
31
+ # Fine-tune the model
32
+ training_args = TrainingArguments(
33
+ output_dir=f"./{dataset_name}_model",
34
+ evaluation_strategy="epoch",
35
+ learning_rate=2e-5,
36
+ per_device_train_batch_size=8,
37
+ per_device_eval_batch_size=8,
38
+ num_train_epochs=3,
39
+ weight_decay=0.01,
40
+ logging_dir='./logs',
41
+ )
42
+
43
+ trainer = Trainer(
44
+ model=model,
45
+ args=training_args,
46
+ train_dataset=tokenized_dataset['train'],
47
+ eval_dataset=tokenized_dataset['validation']
48
+ )
49
+
50
+ trainer.train()
51
+
52
+ # Save the model weights
53
+ model.save_pretrained(f"./{dataset_name}_model")
54
+ tokenizer.save_pretrained(f"./{dataset_name}_model")
55
+
56
+ return f"Model trained and saved for {dataset_name}!"
57
+
58
+ # Gradio Interface
59
+ with gr.Blocks() as demo:
60
+ gr.Markdown("## Train QA Model on Multiple Datasets")
61
+ dataset_name = gr.Dropdown(choices=list(datasets_info.keys()), label="Select Dataset")
62
+ train_button = gr.Button("Train Model")
63
+ output = gr.Textbox(label="Output")
64
+
65
+ def train_and_display(dataset_name):
66
+ return train_model(dataset_name)
67
+
68
+ train_button.click(train_and_display, inputs=dataset_name, outputs=output)
69
+
70
+ demo.launch()