amine-araich committed
Commit f07479b · verified · 1 Parent(s): 8eab63f

Added more epochs for training the model

Files changed (1): model.py (+70 -70)
model.py CHANGED
@@ -1,70 +1,70 @@
- import torch
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
- from transformers import Trainer, TrainingArguments, GenerationConfig
-
-
- def load_model(model_name="facebook/bart-large-cnn"):
-     """
-     Load a pre-trained summarization model
-     Options: facebook/bart-large-cnn, google/pegasus-xsum, sshleifer/distilbart-cnn-12-6
-     """
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-     model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
-     return model, tokenizer
-
-
- def fine_tune_model(model, tokenizer, dataset, output_dir="./summarization_model"):
-     """Fine-tune model on prepared dataset"""
-     training_args = TrainingArguments(
-         output_dir=output_dir,
-         per_device_train_batch_size=4,
-         per_device_eval_batch_size=4,
-         gradient_accumulation_steps=4,
-         learning_rate=5e-5,
-         num_train_epochs=6,
-         save_strategy="epoch",
-         eval_strategy="epoch",
-         load_best_model_at_end=True,
-         report_to="none",
-     )
-
-     trainer = Trainer(
-         model=model,
-         args=training_args,
-         train_dataset=dataset["train"],
-         eval_dataset=dataset["validation"],
-         tokenizer=tokenizer,
-     )
-
-     trainer.train()
-
-     tokenizer.save_pretrained(output_dir)
-     model.save_pretrained(output_dir)
-
-     return model
-
-
- def generate_stylized_summary(text, model, tokenizer, style="formal", max_length=150):
-     """Generate a summary in the specified style"""
-     # Prepend style token to input
-     styled_input = f"[{style.upper()}] {text}"
-     inputs = tokenizer(
-         styled_input, return_tensors="pt", max_length=1024, truncation=True
-     )
-
-     generation_config = GenerationConfig(
-         max_length=max_length,
-         min_length=56,
-         early_stopping=True,
-         num_beams=4,
-         length_penalty=2.0,
-         no_repeat_ngram_size=3,
-         forced_bos_token_id=0,
-     )
-
-     summary_ids = model.generate(
-         inputs["input_ids"], generation_config=generation_config
-     )
-
-     summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-     return summary
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ from transformers import Trainer, TrainingArguments, GenerationConfig
+
+
+ def load_model(model_name="facebook/bart-large-cnn"):
+     """
+     Load a pre-trained summarization model
+     Options: facebook/bart-large-cnn, google/pegasus-xsum, sshleifer/distilbart-cnn-12-6
+     """
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+     return model, tokenizer
+
+
+ def fine_tune_model(model, tokenizer, dataset, output_dir="./summarization_model"):
+     """Fine-tune model on prepared dataset"""
+     training_args = TrainingArguments(
+         output_dir=output_dir,
+         per_device_train_batch_size=4,
+         per_device_eval_batch_size=4,
+         gradient_accumulation_steps=4,
+         learning_rate=5e-5,
+         num_train_epochs=20,
+         save_strategy="epoch",
+         eval_strategy="epoch",
+         load_best_model_at_end=True,
+         report_to="none",
+     )
+
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=dataset["train"],
+         eval_dataset=dataset["validation"],
+         tokenizer=tokenizer,
+     )
+
+     trainer.train()
+
+     tokenizer.save_pretrained(output_dir)
+     model.save_pretrained(output_dir)
+
+     return model
+
+
+ def generate_stylized_summary(text, model, tokenizer, style="formal", max_length=150):
+     """Generate a summary in the specified style"""
+     # Prepend style token to input
+     styled_input = f"[{style.upper()}] {text}"
+     inputs = tokenizer(
+         styled_input, return_tensors="pt", max_length=1024, truncation=True
+     )
+
+     generation_config = GenerationConfig(
+         max_length=max_length,
+         min_length=56,
+         early_stopping=True,
+         num_beams=4,
+         length_penalty=2.0,
+         no_repeat_ngram_size=3,
+         forced_bos_token_id=0,
+     )
+
+     summary_ids = model.generate(
+         inputs["input_ids"], generation_config=generation_config
+     )
+
+     summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+     return summary
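
The only substantive change in this commit is num_train_epochs going from 6 to 20. For context, below is a minimal end-to-end sketch of how these functions fit together. It is not part of the commit: the dataset (cnn_dailymail), its column names, and the style-prefix preprocessing are assumptions chosen to match what fine_tune_model and generate_stylized_summary expect, namely a tokenized DatasetDict with "train"/"validation" splits and inputs carrying the same [STYLE] prefix used at inference time.

# Usage sketch (not part of this commit); assumes model.py as above.
from datasets import load_dataset

from model import load_model, fine_tune_model, generate_stylized_summary

model, tokenizer = load_model("facebook/bart-large-cnn")

# Assumption: any document/summary dataset works; cnn_dailymail is one choice.
raw = load_dataset("cnn_dailymail", "3.0.0")

def preprocess(batch):
    # Prefix the same style token that generate_stylized_summary prepends at
    # inference time, so the model also sees it during fine-tuning (assumption).
    styled = [f"[FORMAL] {doc}" for doc in batch["article"]]
    model_inputs = tokenizer(
        styled, max_length=1024, truncation=True, padding="max_length"
    )
    labels = tokenizer(
        text_target=batch["highlights"],
        max_length=150,
        truncation=True,
        padding="max_length",
    )
    # Replace label padding with -100 so padded positions are ignored by the
    # loss; fine_tune_model's Trainer then consumes fixed-length examples as-is.
    model_inputs["labels"] = [
        [tok if tok != tokenizer.pad_token_id else -100 for tok in seq]
        for seq in labels["input_ids"]
    ]
    return model_inputs

dataset = raw.map(
    preprocess, batched=True, remove_columns=raw["train"].column_names
)

model = fine_tune_model(model, tokenizer, dataset)
print(
    generate_stylized_summary(
        "Some long article text ...", model, tokenizer, style="formal"
    )
)

Since fine_tune_model sets load_best_model_at_end=True with per-epoch evaluation, raising the epoch count from 6 to 20 mainly buys more chances to land on a better checkpoint; the one with the lowest validation loss is restored at the end, though the extra epochs still cost compute and the later checkpoints risk overfitting. Passing transformers.EarlyStoppingCallback to the Trainer would stop sooner, but fine_tune_model does not currently expose callbacks.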