Perfectyash commited on
Commit
5fc5bfd
·
verified ·
1 Parent(s): 3e428f5

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +52 -0
train.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from datasets import load_dataset
from transformers import (
    DistilBertTokenizerFast,
    DistilBertForSequenceClassification,
    Trainer,
    TrainingArguments
)
import pandas as pd  # file-level import retained; no longer used after the label-mapping fix

# Fine-tune DistilBERT for 3-way risk classification.
# Expects data.csv with a "text" column and a "label" column whose values
# are "Low Risk" / "Medium Risk" / "High Risk" — TODO confirm against the data.

# String label -> integer class id, matching num_labels=3 below.
label_map = {"Low Risk": 0, "Medium Risk": 1, "High Risk": 2}

# Load the CSV once as the training split.
# BUG FIX: the original mapped labels on a pandas DataFrame that was then
# discarded — the HF dataset was re-loaded from the raw CSV, so the Trainer
# received string labels and training would fail. Map inside the dataset.
dataset = load_dataset("csv", data_files={"train": "data.csv"})
dataset = dataset.map(lambda example: {"label": label_map[example["label"]]})

# Tokenizer
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

def tokenize(batch):
    # Pad each mapped batch to its longest sequence and truncate to the
    # model's maximum length.
    return tokenizer(batch["text"], padding=True, truncation=True)

dataset = dataset.map(tokenize, batched=True)

# Model: DistilBERT with a fresh 3-class classification head.
model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=3
)

# Training configuration: no eval split is provided, so evaluation is off.
# NOTE(review): `evaluation_strategy` was renamed `eval_strategy` in recent
# transformers releases — confirm against the installed version.
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="no",
    per_device_train_batch_size=4,
    num_train_epochs=3,
    save_strategy="epoch",
    logging_dir="./logs"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"]
)

trainer.train()

# Persist model weights and tokenizer together so the directory can be
# reloaded with from_pretrained("./model").
model.save_pretrained("./model")
tokenizer.save_pretrained("./model")