File size: 4,337 Bytes
5bed521
 
 
 
6de6b50
e169923
5bed521
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44f67db
486dbd6
5bed521
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
486dbd6
5bed521
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
486dbd6
5bed521
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# summarizer.py
import os
import math
import torch
from transformers import BartTokenizer, BartForConditionalGeneration

# Конфигурация: fine-tuned модель атауы немесе default
MODEL_NAME = os.getenv("FINE_TUNED_MODEL", "facebook/bart-large-cnn")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Инициализация (бір рет)
tokenizer = BartTokenizer.from_pretrained(MODEL_NAME)
model = BartForConditionalGeneration.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()

# Параметрлер
MAX_INPUT_LENGTH = 1024
SUMMARY_MIN_LENGTH = 40
SUMMARY_MAX_LENGTH = 200
NUM_BEAMS = 4

def chunk_text(text: str, max_tokens: int = MAX_INPUT_LENGTH, overlap: int = 50):
    """
    Ұзын мәтінді токендер бойынша бөліп қайтару. overlap — әр кусок арасында қайталанатын токен саны.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=False)
    input_ids = inputs["input_ids"][0].tolist()
    chunks = []
    start = 0
    while start < len(input_ids):
        end = start + max_tokens
        chunk_ids = input_ids[start:end]
        chunk_text = tokenizer.decode(chunk_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        chunks.append(chunk_text)
        if end >= len(input_ids):
            break
        start = end - overlap
    return chunks

def generate_summary(text: str) -> str:
    """
    Егер мәтін MAX_INPUT_LENGTH-тен ұзын болса — бөліп, әр бөліктің summary алып,
    содан кейін қысқа unified summary қайтару.
    """
    text = text.strip()
    if not text:
        return ""

    # Егер қысқа — тікелей summary
    tokens = tokenizer(text, max_length=1, truncation=False)
    # Қарапайым жүктеме: егер мәтін қысқа — бір шақыру
    if len(tokenizer.encode(text)) <= MAX_INPUT_LENGTH:
        inputs = tokenizer([text], max_length=MAX_INPUT_LENGTH, truncation=True, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            summary_ids = model.generate(
                inputs["input_ids"],
                attention_mask=inputs.get("attention_mask", None),
                num_beams=NUM_BEAMS,
                min_length=SUMMARY_MIN_LENGTH,
                max_length=SUMMARY_MAX_LENGTH,
                early_stopping=True,
                no_repeat_ngram_size=3
            )
        return tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

    # Ұзын мәтін: бөліп, әр бөлімнің summary алып, содан кейін агрегаттау
    chunks = chunk_text(text, max_tokens=MAX_INPUT_LENGTH, overlap=64)
    partial_summaries = []
    for chunk in chunks:
        inputs = tokenizer([chunk], max_length=MAX_INPUT_LENGTH, truncation=True, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            summary_ids = model.generate(
                inputs["input_ids"],
                attention_mask=inputs.get("attention_mask", None),
                num_beams=NUM_BEAMS,
                min_length=SUMMARY_MIN_LENGTH // 2,
                max_length=SUMMARY_MAX_LENGTH,
                early_stopping=True,
                no_repeat_ngram_size=3
            )
        s = tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
        partial_summaries.append(s)

    # Біріктіру: partial_summaries-тан соңғы қысқаша summary жасау
    combined = "\n\n".join(partial_summaries)
    # Егер combined тым ұзын болса — қысқаша summary
    inputs = tokenizer([combined], max_length=MAX_INPUT_LENGTH, truncation=True, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        summary_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs.get("attention_mask", None),
            num_beams=NUM_BEAMS,
            min_length=SUMMARY_MIN_LENGTH,
            max_length=SUMMARY_MAX_LENGTH,
            early_stopping=True,
            no_repeat_ngram_size=3
        )
    final_summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return final_summary