pangxiang commited on
Commit
80b3a89
·
verified ·
1 Parent(s): 0d29db9

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +23 -25
train.py CHANGED
@@ -1,29 +1,27 @@
1
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
2
- from datasets import Dataset
3
  import json
 
 
4
 
5
- with open('train_data.json', 'r') as f:
6
- data = json.load(f)
 
7
 
8
- texts = []
9
- labels = []
10
- for label, samples in data.items():
11
- for text in samples:
12
- texts.append(text)
13
- labels.append(label)
 
 
 
 
 
 
 
 
 
 
14
 
15
- dataset = Dataset.from_dict({"text": texts, "label": labels})
16
-
17
- model_name = "prajjwal1/bert-tiny"
18
- tokenizer = AutoTokenizer.from_pretrained(model_name)
19
- model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
20
-
21
- def tokenize_function(examples):
22
- return tokenizer(examples["text"], padding="max_length", truncation=True)
23
-
24
- tokenized_datasets = dataset.map(tokenize_function, batched=True)
25
-
26
- training_args = TrainingArguments(output_dir="results", num_train_epochs=2)
27
- trainer = Trainer(model=model, args=training_args, train_dataset=tokenized_datasets)
28
- trainer.train()
29
- model.save_pretrained("trained_model")
 
 
 
1
  import json
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
4
 
5
+ def load_training_data():
6
+ with open('train_data.json', 'r', encoding='utf-8') as f:
7
+ return json.load(f)
8
 
9
+ def train_model():
10
+ # 加载数据
11
+ data = load_training_data()
12
+
13
+ # 初始化tokenizer和模型(使用小模型)
14
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
15
+ model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
16
+
17
+ # 数据处理和训练逻辑
18
+ # ... 这里添加你的训练代码
19
+
20
+ # 保存模型
21
+ model.save_pretrained("./trained_model")
22
+ tokenizer.save_pretrained("./trained_model")
23
+
24
+ print("模型训练完成!")
25
 
26
+ if __name__ == "__main__":
27
+ train_model()