student2222333051 commited on
Commit
5bed521
·
verified ·
1 Parent(s): 41edee5

Update summarizer.py

Browse files
Files changed (1) hide show
  1. summarizer.py +93 -11
summarizer.py CHANGED
@@ -1,17 +1,99 @@
 
 
 
 
1
  from transformers import BartTokenizer, BartForConditionalGeneration
2
 
3
- tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
4
- model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  def generate_summary(text: str) -> str:
7
- inputs = tokenizer([text], max_length=1024, truncation=True, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- summary_ids = model.generate(
10
- inputs["input_ids"],
11
- num_beams=4,
12
- min_length=40,
13
- max_length=200,
14
- early_stopping=True
15
- )
 
 
 
 
 
 
 
 
 
 
16
 
17
- return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # summarizer.py
2
+ import os
3
+ import math
4
+ import torch
5
  from transformers import BartTokenizer, BartForConditionalGeneration
6
 
7
+ # Конфигурация: fine-tuned модель атауы немесе default
8
+ MODEL_NAME = os.getenv("FINE_TUNED_MODEL", "facebook/bart-large-cnn")
9
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
10
+
11
+ # Инициализация (бір рет)
12
+ tokenizer = BartTokenizer.from_pretrained(MODEL_NAME)
13
+ model = BartForConditionalGeneration.from_pretrained(MODEL_NAME).to(DEVICE)
14
+ model.eval()
15
+
16
+ # Параметрлер
17
+ MAX_INPUT_LENGTH = 1024
18
+ SUMMARY_MIN_LENGTH = 40
19
+ SUMMARY_MAX_LENGTH = 200
20
+ NUM_BEAMS = 4
21
+
22
+ def chunk_text(text: str, max_tokens: int = MAX_INPUT_LENGTH, overlap: int = 50):
23
+ """
24
+ Ұзын мәтінді токендер бойынша бөліп қайтару. overlap — әр кусок арасында қайталанатын токен саны.
25
+ """
26
+ inputs = tokenizer(text, return_tensors="pt", truncation=False)
27
+ input_ids = inputs["input_ids"][0].tolist()
28
+ chunks = []
29
+ start = 0
30
+ while start < len(input_ids):
31
+ end = start + max_tokens
32
+ chunk_ids = input_ids[start:end]
33
+ chunk_text = tokenizer.decode(chunk_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
34
+ chunks.append(chunk_text)
35
+ if end >= len(input_ids):
36
+ break
37
+ start = end - overlap
38
+ return chunks
39
 
40
  def generate_summary(text: str) -> str:
41
+ """
42
+ Егер мәтін MAX_INPUT_LENGTH-тен ұзын болса — бөліп, әр бөліктің summary алып,
43
+ содан кейін қысқа unified summary қайтару.
44
+ """
45
+ text = text.strip()
46
+ if not text:
47
+ return ""
48
+
49
+ # Егер қысқа — тікелей summary
50
+ tokens = tokenizer(text, max_length=1, truncation=False)
51
+ # Қарапайым жүктеме: егер мәтін қысқа — бір шақыру
52
+ if len(tokenizer.encode(text)) <= MAX_INPUT_LENGTH:
53
+ inputs = tokenizer([text], max_length=MAX_INPUT_LENGTH, truncation=True, return_tensors="pt").to(DEVICE)
54
+ with torch.no_grad():
55
+ summary_ids = model.generate(
56
+ inputs["input_ids"],
57
+ attention_mask=inputs.get("attention_mask", None),
58
+ num_beams=NUM_BEAMS,
59
+ min_length=SUMMARY_MIN_LENGTH,
60
+ max_length=SUMMARY_MAX_LENGTH,
61
+ early_stopping=True,
62
+ no_repeat_ngram_size=3
63
+ )
64
+ return tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
65
 
66
+ # Ұзын мәтін: бөліп, әр бөлімнің summary алып, содан кейін агрегаттау
67
+ chunks = chunk_text(text, max_tokens=MAX_INPUT_LENGTH, overlap=64)
68
+ partial_summaries = []
69
+ for chunk in chunks:
70
+ inputs = tokenizer([chunk], max_length=MAX_INPUT_LENGTH, truncation=True, return_tensors="pt").to(DEVICE)
71
+ with torch.no_grad():
72
+ summary_ids = model.generate(
73
+ inputs["input_ids"],
74
+ attention_mask=inputs.get("attention_mask", None),
75
+ num_beams=NUM_BEAMS,
76
+ min_length=SUMMARY_MIN_LENGTH // 2,
77
+ max_length=SUMMARY_MAX_LENGTH,
78
+ early_stopping=True,
79
+ no_repeat_ngram_size=3
80
+ )
81
+ s = tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
82
+ partial_summaries.append(s)
83
 
84
+ # Біріктіру: partial_summaries-тан соңғы қысқаша summary жасау
85
+ combined = "\n\n".join(partial_summaries)
86
+ # Егер combined тым ұзын болса — қысқаша summary
87
+ inputs = tokenizer([combined], max_length=MAX_INPUT_LENGTH, truncation=True, return_tensors="pt").to(DEVICE)
88
+ with torch.no_grad():
89
+ summary_ids = model.generate(
90
+ inputs["input_ids"],
91
+ attention_mask=inputs.get("attention_mask", None),
92
+ num_beams=NUM_BEAMS,
93
+ min_length=SUMMARY_MIN_LENGTH,
94
+ max_length=SUMMARY_MAX_LENGTH,
95
+ early_stopping=True,
96
+ no_repeat_ngram_size=3
97
+ )
98
+ final_summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
99
+ return final_summary