student2222333051 commited on
Commit
fbd3dd8
·
verified ·
1 Parent(s): 0a5b809

Create ine_tune.py

Browse files
Files changed (1) hide show
  1. ine_tune.py +76 -0
ine_tune.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # fine_tune.py
2
+ from datasets import load_dataset, load_metric
3
+ from transformers import BartTokenizer, BartForConditionalGeneration, Trainer, TrainingArguments
4
+
5
+ # 1️⃣ Деректерді жүктеу (ArXiv)
6
+ dataset = load_dataset("scientific_papers", "arxiv")
7
+
8
+ # Шағын subset (тест үшін)
9
+ dataset["train"] = dataset["train"].select(range(1000))
10
+ dataset["validation"] = dataset["validation"].select(range(200))
11
+
12
+ # 2️⃣ Tokenizer және модель
13
+ model_name = "facebook/bart-large-cnn"
14
+ tokenizer = BartTokenizer.from_pretrained(model_name)
15
+ model = BartForConditionalGeneration.from_pretrained(model_name)
16
+
17
+ max_input_length = 1024
18
+ max_output_length = 200
19
+
20
+ # 3️⃣ Tokenization
21
+ def preprocess_function(batch):
22
+ inputs = tokenizer(batch["article"], max_length=max_input_length, truncation=True)
23
+ outputs = tokenizer(batch["abstract"], max_length=max_output_length, truncation=True)
24
+ batch["input_ids"] = inputs["input_ids"]
25
+ batch["attention_mask"] = inputs["attention_mask"]
26
+ batch["labels"] = outputs["input_ids"]
27
+ return batch
28
+
29
+ tokenized_train = dataset["train"].map(preprocess_function, batched=True)
30
+ tokenized_val = dataset["validation"].map(preprocess_function, batched=True)
31
+
32
+ # 4️⃣ ROUGE метрика
33
+ rouge = load_metric("rouge")
34
+
35
+ def compute_metrics(eval_pred):
36
+ predictions, labels = eval_pred
37
+ decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
38
+ decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
39
+ result = rouge.compute(predictions=decoded_preds, references=decoded_labels)
40
+ return {key: value.mid.fmeasure * 100 for key, value in result.items()}
41
+
42
+ # 5️⃣ TrainingArguments
43
+ training_args = TrainingArguments(
44
+ output_dir="./bart-finetuned-arxiv",
45
+ evaluation_strategy="steps",
46
+ eval_steps=500,
47
+ save_steps=500,
48
+ save_total_limit=2,
49
+ learning_rate=3e-5,
50
+ per_device_train_batch_size=2,
51
+ per_device_eval_batch_size=2,
52
+ num_train_epochs=3,
53
+ weight_decay=0.01,
54
+ fp16=True,
55
+ logging_dir="./logs",
56
+ logging_steps=100,
57
+ )
58
+
59
+ # 6️⃣ Trainer
60
+ trainer = Trainer(
61
+ model=model,
62
+ args=training_args,
63
+ train_dataset=tokenized_train,
64
+ eval_dataset=tokenized_val,
65
+ tokenizer=tokenizer,
66
+ compute_metrics=compute_metrics,
67
+ )
68
+
69
+ # 7️⃣ Fine-tune бастау
70
+ trainer.train()
71
+
72
+ # 8️⃣ Модельді сақтау
73
+ model.save_pretrained("./bart-finetuned-arxiv")
74
+ tokenizer.save_pretrained("./bart-finetuned-arxiv")
75
+
76
+ print("Fine-tuning аяқталды! Модель сақталды ./bart-finetuned-arxiv")