File size: 5,285 Bytes
ca8ebf7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import re
from typing import Optional

class TaskSummarizer:
    def __init__(self, model_name="cointegrated/rut5-base-absum"):
        """
        Инициализация модели для суммаризации
        """
        self.model_name = model_name
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = None
        self.model = None
        
        print(f"🔄 Загрузка модели {model_name}...")
        print(f"📱 Устройство: {self.device}")
        
        try:
            # Загружаем токенизатор и модель
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
            self.model.to(self.device)
            self.model.eval()  # Режим оценки (не обучения)
            print("✅ Модель успешно загружена!")
        except Exception as e:
            print(f"❌ Ошибка при загрузке модели: {e}")
            print("💡 Попробуйте выполнить: pip install --upgrade transformers torch")
            raise
    
    def summarize(self, text: str, max_length: int = 50, min_length: int = 10) -> str:
        """
        Создает краткую суммаризацию текста задачи
        
        Args:
            text: Полный текст задачи
            max_length: Максимальная длина суммаризации
            min_length: Минимальная длина суммаризации
            
        Returns:
            Краткое описание задачи
        """
        if not text or len(text) < 20:
            return text
        
        try:
            # Очищаем текст от лишних символов
            text = self._clean_text(text)
            
            # Токенизируем входной текст
            inputs = self.tokenizer(
                text, 
                max_length=512, 
                truncation=True, 
                return_tensors="pt"
            ).to(self.device)
            
            # Генерируем суммаризацию
            with torch.no_grad():  # Отключаем вычисление градиентов для экономии памяти
                summary_ids = self.model.generate(
                    inputs.input_ids,
                    max_length=max_length,
                    min_length=min_length,
                    num_beams=4,  # Поиск с лучом для лучшего качества
                    length_penalty=2.0,  # Штраф за длину
                    early_stopping=True,
                    no_repeat_ngram_size=3  # Избегаем повторений
                )
            
            # Декодируем результат
            summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
            
            # Постобработка
            summary = self._postprocess_summary(summary)
            
            return summary
            
        except Exception as e:
            print(f"⚠️ Ошибка при суммаризации: {e}")
            # Возвращаем первые 100 символов как запасной вариант
            return text[:100] + "..."
    
    def _clean_text(self, text: str) -> str:
        """Очищает текст от лишних символов"""
        # Удаляем номер задачи в начале (если есть)
        text = re.sub(r'^\d+\.\s*', '', text)
        
        # Удаляем информацию об ответственном и сроке
        text = re.sub(r'Отв\.:.*?Срок\s*-\s*\d{2}\.\d{2}\.\d{4}', '', text)
        text = re.sub(r'Отв\.:.*$', '', text, flags=re.MULTILINE)
        text = re.sub(r'Срок\s*-\s*\d{2}\.\d{2}\.\d{4}', '', text)
        
        # Удаляем лишние пробелы
        text = re.sub(r'\s+', ' ', text)
        
        return text.strip()
    
    def _postprocess_summary(self, summary: str) -> str:
        """Постобработка сгенерированной суммаризации"""
        # Убираем лишние пробелы
        summary = re.sub(r'\s+', ' ', summary)
        
        # Убираем точку в конце, если её нет
        if summary and not summary.endswith(('.', '!', '?')):
            summary += '.'
        
        # Делаем первую букву заглавной
        if summary:
            summary = summary[0].upper() + summary[1:]
        
        return summary
    
    def summarize_batch(self, texts, max_length=50, min_length=10):
        """
        Суммаризация нескольких текстов (для эффективности)
        """
        summaries = []
        for text in texts:
            summary = self.summarize(text, max_length, min_length)
            summaries.append(summary)
        return summaries