File size: 1,926 Bytes
fbd3dd8
887395f
fbd3dd8
887395f
 
 
 
 
fbd3dd8
1f098c8
fbd3dd8
887395f
fbd3dd8
 
 
 
 
 
 
 
1f098c8
fbd3dd8
 
 
 
 
887395f
 
fbd3dd8
 
1f098c8
fbd3dd8
 
 
 
 
 
 
 
 
887395f
fbd3dd8
 
887395f
fbd3dd8
 
 
 
 
 
 
 
 
 
 
1f098c8
887395f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# fine_tune.py
from datasets import load_dataset
from transformers import BartTokenizer, BartForConditionalGeneration, Trainer, TrainingArguments
import os

# Fine-tune the BART CNN/DailyMail summarization checkpoint on arXiv papers.
model_name = "facebook/bart-large-cnn"

# arXiv scientific-papers dataset (promptsource variant) from the Hub.
dataset = load_dataset("marcov/scientific_papers_arxiv_promptsource")

# Tokenizer and seq2seq model from the same pretrained checkpoint.
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)

# Keep only a small slice of each split for quick smoke-test runs.
for split, keep in (("train", 1000), ("validation", 200)):
    dataset[split] = dataset[split].select(range(keep))

max_input_length = 1024  # token cap for the article (encoder input)
max_output_length = 200  # token cap for the summary (labels)

def preprocess_function(batch):
    """Tokenize articles (inputs) and summaries (labels) for seq2seq training.

    Pads/truncates every sequence to a fixed length so the Trainer's default
    data collator can stack the features into tensors (truncation-only output
    has variable lengths and fails to collate), and masks padding positions in
    the labels with -100 so they are ignored by the loss.
    """
    inputs = tokenizer(
        batch["article"],
        max_length=max_input_length,
        padding="max_length",
        truncation=True,
    )
    outputs = tokenizer(
        batch["summary"],
        max_length=max_output_length,
        padding="max_length",
        truncation=True,
    )
    batch["input_ids"] = inputs["input_ids"]
    batch["attention_mask"] = inputs["attention_mask"]
    # -100 is the ignore_index of the cross-entropy loss in transformers, so
    # padded label positions contribute nothing to training.
    batch["labels"] = [
        [tok if tok != tokenizer.pad_token_id else -100 for tok in seq]
        for seq in outputs["input_ids"]
    ]
    return batch

def _tokenize_split(split):
    # Map the preprocessing over one split, dropping the raw text columns.
    ds = dataset[split]
    return ds.map(preprocess_function, batched=True, remove_columns=ds.column_names)

tokenized_train = _tokenize_split("train")
tokenized_val = _tokenize_split("validation")

# Training configuration. hub_model_id names the Hub repo explicitly; without
# it, push_to_hub derives the repo name from output_dir instead of the repo
# the final push in this script intends ("username/bart-finetuned-arxiv").
training_args = TrainingArguments(
    output_dir="./bart-finetuned-arxiv-hub",
    evaluation_strategy="steps",
    eval_steps=500,
    save_steps=500,
    save_total_limit=2,
    learning_rate=3e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=3,
    weight_decay=0.01,
    fp16=False,  # set True when training on a CUDA GPU
    logging_dir="./logs",
    logging_steps=100,
    push_to_hub=True,
    hub_model_id="username/bart-finetuned-arxiv",
)

# Assemble the Trainer from the pieces configured above.
trainer_kwargs = dict(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    tokenizer=tokenizer,
)
trainer = Trainer(**trainer_kwargs)

trainer.train()
# Trainer.push_to_hub()'s first positional parameter is commit_message, NOT a
# repo id — passing "username/bart-finetuned-arxiv" positionally only sets the
# commit message. The target repo comes from TrainingArguments
# (hub_model_id, falling back to output_dir's name).
trainer.push_to_hub(commit_message="Fine-tune BART on arXiv summaries")
print("Fine-tuning complete.")