import os

import torch
from transformers import BartTokenizer, BartForConditionalGeneration
|
MODEL_NAME = os.getenv("FINE_TUNED_MODEL", "facebook/bart-large-cnn")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = BartTokenizer.from_pretrained(MODEL_NAME)
model = BartForConditionalGeneration.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()

MAX_INPUT_LENGTH = 1024
SUMMARY_MIN_LENGTH = 40
SUMMARY_MAX_LENGTH = 200
NUM_BEAMS = 4


def chunk_text(text: str, max_tokens: int = MAX_INPUT_LENGTH, overlap: int = 50):
    """
    Split a long text into chunks of at most `max_tokens` tokens.
    `overlap` is the number of tokens repeated between consecutive chunks,
    so content cut at a boundary still appears at the start of the next chunk.
    """
    # Tokenize without truncation; for texts longer than the model maximum
    # this may log a length warning, which is expected here.
    inputs = tokenizer(text, return_tensors="pt", truncation=False)
    input_ids = inputs["input_ids"][0].tolist()
    chunks = []
    start = 0
    while start < len(input_ids):
        end = start + max_tokens
        chunk_ids = input_ids[start:end]
        # Decode the token window back to a string (renamed so it does not
        # shadow the enclosing chunk_text function).
        chunk_str = tokenizer.decode(chunk_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        chunks.append(chunk_str)
        if end >= len(input_ids):
            break
        # Step back by `overlap` tokens so consecutive chunks share context.
        start = end - overlap
    return chunks
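
# Illustrative usage sketch for chunk_text (commented out: `long_text` is a
# hypothetical variable, not defined in this module). Note that re-encoding a
# decoded chunk may not give exactly max_tokens tokens, since decode/encode
# round-trips can merge or split tokens and re-add special tokens:
#
#     chunks = chunk_text(long_text, max_tokens=512, overlap=50)
#     for i, c in enumerate(chunks):
#         print(i, len(tokenizer.encode(c)))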
|
|
def _summarize(text: str, min_length: int) -> str:
    """Tokenize `text`, run beam-search generation, and decode the summary."""
    inputs = tokenizer([text], max_length=MAX_INPUT_LENGTH, truncation=True, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        summary_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            num_beams=NUM_BEAMS,
            min_length=min_length,
            max_length=SUMMARY_MAX_LENGTH,
            early_stopping=True,
            no_repeat_ngram_size=3,
        )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)


def generate_summary(text: str) -> str:
    """
    If the text is longer than MAX_INPUT_LENGTH tokens, split it into chunks,
    summarize each chunk, and then return a short unified summary of the
    combined partial summaries.
    """
    text = text.strip()
    if not text:
        return ""

    # Short texts fit in a single pass. Encoding without truncation may log
    # a length warning for long texts; we only need the token count here.
    if len(tokenizer.encode(text)) <= MAX_INPUT_LENGTH:
        return _summarize(text, SUMMARY_MIN_LENGTH)

    # Map step: summarize each overlapping chunk, with a relaxed minimum
    # length since a single chunk carries less content than the full document.
    chunks = chunk_text(text, max_tokens=MAX_INPUT_LENGTH, overlap=64)
    partial_summaries = [_summarize(chunk, SUMMARY_MIN_LENGTH // 2) for chunk in chunks]

    # Reduce step: summarize the concatenation of the partial summaries into
    # one unified summary.
    combined = "\n\n".join(partial_summaries)
    return _summarize(combined, SUMMARY_MIN_LENGTH)
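
# Illustrative entry point (an assumption, not part of the original module:
# the default input path "sample.txt" is hypothetical).
if __name__ == "__main__":
    import sys

    # Summarize a text file passed on the command line, or a default path.
    path = sys.argv[1] if len(sys.argv) > 1 else "sample.txt"
    with open(path, encoding="utf-8") as f:
        document = f.read()
    print(generate_summary(document))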