summarizer_space / summarizer.py
student2222333051's picture
Update summarizer.py
5bed521 verified
# summarizer.py
import os
import math
import torch
from transformers import BartTokenizer, BartForConditionalGeneration
# Конфигурация: fine-tuned модель атауы немесе default
MODEL_NAME = os.getenv("FINE_TUNED_MODEL", "facebook/bart-large-cnn")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Инициализация (бір рет)
tokenizer = BartTokenizer.from_pretrained(MODEL_NAME)
model = BartForConditionalGeneration.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()
# Параметрлер
MAX_INPUT_LENGTH = 1024
SUMMARY_MIN_LENGTH = 40
SUMMARY_MAX_LENGTH = 200
NUM_BEAMS = 4
def chunk_text(text: str, max_tokens: int = MAX_INPUT_LENGTH, overlap: int = 50):
"""
Ұзын мәтінді токендер бойынша бөліп қайтару. overlap — әр кусок арасында қайталанатын токен саны.
"""
inputs = tokenizer(text, return_tensors="pt", truncation=False)
input_ids = inputs["input_ids"][0].tolist()
chunks = []
start = 0
while start < len(input_ids):
end = start + max_tokens
chunk_ids = input_ids[start:end]
chunk_text = tokenizer.decode(chunk_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
chunks.append(chunk_text)
if end >= len(input_ids):
break
start = end - overlap
return chunks
def generate_summary(text: str) -> str:
"""
Егер мәтін MAX_INPUT_LENGTH-тен ұзын болса — бөліп, әр бөліктің summary алып,
содан кейін қысқа unified summary қайтару.
"""
text = text.strip()
if not text:
return ""
# Егер қысқа — тікелей summary
tokens = tokenizer(text, max_length=1, truncation=False)
# Қарапайым жүктеме: егер мәтін қысқа — бір шақыру
if len(tokenizer.encode(text)) <= MAX_INPUT_LENGTH:
inputs = tokenizer([text], max_length=MAX_INPUT_LENGTH, truncation=True, return_tensors="pt").to(DEVICE)
with torch.no_grad():
summary_ids = model.generate(
inputs["input_ids"],
attention_mask=inputs.get("attention_mask", None),
num_beams=NUM_BEAMS,
min_length=SUMMARY_MIN_LENGTH,
max_length=SUMMARY_MAX_LENGTH,
early_stopping=True,
no_repeat_ngram_size=3
)
return tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
# Ұзын мәтін: бөліп, әр бөлімнің summary алып, содан кейін агрегаттау
chunks = chunk_text(text, max_tokens=MAX_INPUT_LENGTH, overlap=64)
partial_summaries = []
for chunk in chunks:
inputs = tokenizer([chunk], max_length=MAX_INPUT_LENGTH, truncation=True, return_tensors="pt").to(DEVICE)
with torch.no_grad():
summary_ids = model.generate(
inputs["input_ids"],
attention_mask=inputs.get("attention_mask", None),
num_beams=NUM_BEAMS,
min_length=SUMMARY_MIN_LENGTH // 2,
max_length=SUMMARY_MAX_LENGTH,
early_stopping=True,
no_repeat_ngram_size=3
)
s = tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
partial_summaries.append(s)
# Біріктіру: partial_summaries-тан соңғы қысқаша summary жасау
combined = "\n\n".join(partial_summaries)
# Егер combined тым ұзын болса — қысқаша summary
inputs = tokenizer([combined], max_length=MAX_INPUT_LENGTH, truncation=True, return_tensors="pt").to(DEVICE)
with torch.no_grad():
summary_ids = model.generate(
inputs["input_ids"],
attention_mask=inputs.get("attention_mask", None),
num_beams=NUM_BEAMS,
min_length=SUMMARY_MIN_LENGTH,
max_length=SUMMARY_MAX_LENGTH,
early_stopping=True,
no_repeat_ngram_size=3
)
final_summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
return final_summary