File size: 4,337 Bytes
5bed521 6de6b50 e169923 5bed521 44f67db 486dbd6 5bed521 486dbd6 5bed521 486dbd6 5bed521 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
# summarizer.py
import os
import math
import torch
from transformers import BartTokenizer, BartForConditionalGeneration
# Конфигурация: fine-tuned модель атауы немесе default
MODEL_NAME = os.getenv("FINE_TUNED_MODEL", "facebook/bart-large-cnn")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Инициализация (бір рет)
tokenizer = BartTokenizer.from_pretrained(MODEL_NAME)
model = BartForConditionalGeneration.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()
# Параметрлер
MAX_INPUT_LENGTH = 1024
SUMMARY_MIN_LENGTH = 40
SUMMARY_MAX_LENGTH = 200
NUM_BEAMS = 4
def chunk_text(text: str, max_tokens: int = MAX_INPUT_LENGTH, overlap: int = 50):
"""
Ұзын мәтінді токендер бойынша бөліп қайтару. overlap — әр кусок арасында қайталанатын токен саны.
"""
inputs = tokenizer(text, return_tensors="pt", truncation=False)
input_ids = inputs["input_ids"][0].tolist()
chunks = []
start = 0
while start < len(input_ids):
end = start + max_tokens
chunk_ids = input_ids[start:end]
chunk_text = tokenizer.decode(chunk_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
chunks.append(chunk_text)
if end >= len(input_ids):
break
start = end - overlap
return chunks
def generate_summary(text: str) -> str:
"""
Егер мәтін MAX_INPUT_LENGTH-тен ұзын болса — бөліп, әр бөліктің summary алып,
содан кейін қысқа unified summary қайтару.
"""
text = text.strip()
if not text:
return ""
# Егер қысқа — тікелей summary
tokens = tokenizer(text, max_length=1, truncation=False)
# Қарапайым жүктеме: егер мәтін қысқа — бір шақыру
if len(tokenizer.encode(text)) <= MAX_INPUT_LENGTH:
inputs = tokenizer([text], max_length=MAX_INPUT_LENGTH, truncation=True, return_tensors="pt").to(DEVICE)
with torch.no_grad():
summary_ids = model.generate(
inputs["input_ids"],
attention_mask=inputs.get("attention_mask", None),
num_beams=NUM_BEAMS,
min_length=SUMMARY_MIN_LENGTH,
max_length=SUMMARY_MAX_LENGTH,
early_stopping=True,
no_repeat_ngram_size=3
)
return tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
# Ұзын мәтін: бөліп, әр бөлімнің summary алып, содан кейін агрегаттау
chunks = chunk_text(text, max_tokens=MAX_INPUT_LENGTH, overlap=64)
partial_summaries = []
for chunk in chunks:
inputs = tokenizer([chunk], max_length=MAX_INPUT_LENGTH, truncation=True, return_tensors="pt").to(DEVICE)
with torch.no_grad():
summary_ids = model.generate(
inputs["input_ids"],
attention_mask=inputs.get("attention_mask", None),
num_beams=NUM_BEAMS,
min_length=SUMMARY_MIN_LENGTH // 2,
max_length=SUMMARY_MAX_LENGTH,
early_stopping=True,
no_repeat_ngram_size=3
)
s = tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
partial_summaries.append(s)
# Біріктіру: partial_summaries-тан соңғы қысқаша summary жасау
combined = "\n\n".join(partial_summaries)
# Егер combined тым ұзын болса — қысқаша summary
inputs = tokenizer([combined], max_length=MAX_INPUT_LENGTH, truncation=True, return_tensors="pt").to(DEVICE)
with torch.no_grad():
summary_ids = model.generate(
inputs["input_ids"],
attention_mask=inputs.get("attention_mask", None),
num_beams=NUM_BEAMS,
min_length=SUMMARY_MIN_LENGTH,
max_length=SUMMARY_MAX_LENGTH,
early_stopping=True,
no_repeat_ngram_size=3
)
final_summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
return final_summary
|