File size: 7,567 Bytes
39028c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
"""
Model architectures and configurations for summarization
"""

import logging
import re
from typing import Optional, List

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    pipeline
)


# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class SummarizationModelLoader:
    """Load and manage pre-trained seq2seq summarization models.

    The loader resolves a short model alias (see ``AVAILABLE_MODELS``) or an
    arbitrary path / model ID, picks a device, and keeps the loaded model and
    tokenizer on the instance as ``self.model`` / ``self.tokenizer``.
    """

    # Short alias -> model ID handed to ``from_pretrained``.
    AVAILABLE_MODELS = {
        # Monolingual summarization checkpoints
        't5-small': 'google-t5/t5-small',
        't5-base': 'google-t5/t5-base',
        't5-large': 'google-t5/t5-large',
        'bart-base': 'facebook/bart-base',
        'bart-large-cnn': 'facebook/bart-large-cnn',
        'pegasus-arxiv': 'google/pegasus-arxiv',
        'pegasus-pubmed': 'google/pegasus-pubmed',
        'led': 'allenai/led-base-16384',  # For long documents

        # Multilingual models
        'mbart-50': 'facebook/mbart-large-50',
        # NOTE(review): confirm this model ID actually exists on the Hub
        # before recommending it as the fast default.
        'mbart-50-small': 'facebook/mbart-large-50-small',
        'mt5-small': 'google/mt5-small',  # Multilingual T5
        'mt5-base': 'google/mt5-base',
    }

    # Human-readable language name -> language code assigned to
    # ``tokenizer.src_lang`` for multilingual models.
    SUPPORTED_LANGUAGES = {
        'english': 'en_XX',
        'spanish': 'es_XX',
        'french': 'fr_XX',
        'german': 'de_DE',
        'italian': 'it_IT',
        'portuguese': 'pt_XX',
        'chinese': 'zh_CN',
        'japanese': 'ja_XX',
        'korean': 'ko_KR',
        'arabic': 'ar_AR',
        'hindi': 'hi_IN',
        'russian': 'ru_RU',
        'turkish': 'tr_TR',
        'vietnamese': 'vi_VN',
        'thai': 'th_TH',
    }

    # Code used when the requested language is not in SUPPORTED_LANGUAGES.
    DEFAULT_LANGUAGE_CODE = 'en_XX'

    def __init__(self, model_name: str = 't5-small', device: Optional[str] = None, language: str = 'english'):
        """
        Initialize model loader.

        Args:
            model_name: Alias from ``AVAILABLE_MODELS`` or any model ID/path
                (default: t5-small for speed)
            device: Device to load model on ('cpu' or 'cuda'); when omitted,
                'cuda' is chosen if available, otherwise 'cpu'
            language: Language name for multilingual models; unknown names
                fall back to English with a logged warning
        """
        self.model_name = model_name
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.language = language

        code = self._lookup_language_code(language)
        if code is None:
            # Keep the historical fallback, but make it visible instead of
            # silently summarizing with the wrong source language.
            logger.warning(
                "Unsupported language %r; falling back to %s",
                language,
                self.DEFAULT_LANGUAGE_CODE,
            )
            code = self.DEFAULT_LANGUAGE_CODE
        self.language_code = code

        self.model = None
        self.tokenizer = None

        logger.info("Using device: %s", self.device)
        logger.info("Language: %s (code: %s)", language, self.language_code)

    @classmethod
    def _lookup_language_code(cls, language) -> Optional[str]:
        """Return the code for *language*, or None when it is not supported.

        Matching ignores case and surrounding whitespace; non-string input
        (e.g. None) is treated as unsupported rather than raising.
        """
        if not isinstance(language, str):
            return None
        return cls.SUPPORTED_LANGUAGES.get(language.strip().lower())

    def load_model(self, model_path: Optional[str] = None) -> tuple:
        """
        Load model and tokenizer.

        ``self.model`` and ``self.tokenizer`` are only updated when both
        loads succeed, so a failed call leaves the previous state intact.

        Args:
            model_path: Path to local model or HuggingFace model ID; when
                omitted, ``self.model_name`` is resolved via AVAILABLE_MODELS
                (unknown names are used verbatim)

        Returns:
            Tuple of (model, tokenizer)

        Raises:
            Exception: Whatever the underlying loaders raise is logged with
                its traceback and re-raised unchanged.
        """
        path = model_path or self.AVAILABLE_MODELS.get(self.model_name, self.model_name)

        try:
            logger.info("Loading tokenizer from %s", path)
            tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True)  # Fast tokenizer

            # Set language for multilingual models.
            # NOTE(review): the codes in SUPPORTED_LANGUAGES look like
            # mBART-style ``xx_YY`` codes; confirm src_lang is meaningful
            # for mT5 tokenizers as well.
            path_lower = str(path).lower()
            if 'mbart' in path_lower or 'mt5' in path_lower:
                tokenizer.src_lang = self.language_code

            logger.info("Loading model from %s", path)
            model = AutoModelForSeq2SeqLM.from_pretrained(path)
            model.to(self.device)
            model.eval()
        except Exception as e:
            # logger.exception keeps the traceback in the log record.
            logger.exception("Error loading model: %s", e)
            raise

        # Commit only after everything succeeded.
        self.model = model
        self.tokenizer = tokenizer

        logger.info("Model loaded successfully on %s", self.device)
        return self.model, self.tokenizer

    def get_model_info(self) -> dict:
        """Get information about loaded model.

        Returns:
            ``{"status": "Model not loaded"}`` before a successful load;
            otherwise the model name, device, total parameter count and
            the count of parameters with ``requires_grad`` set.
        """
        if self.model is None:
            return {"status": "Model not loaded"}

        # Single pass over the parameters for both counts.
        total = 0
        trainable = 0
        for param in self.model.parameters():
            count = param.numel()
            total += count
            if param.requires_grad:
                trainable += count

        return {
            "model_name": self.model_name,
            "device": self.device,
            "parameters": total,
            "trainable_parameters": trainable,
        }


class IntentClassifier:
    """Classify user intent for summarization.

    Classification is a plain substring match of each intent name (with
    underscores read as spaces) against the user's text; the first intent
    in ``INTENT_TYPES`` order that matches wins.
    """

    # Intent key -> description used to build the prompt for that intent.
    INTENT_TYPES = {
        'technical_overview': 'high-level technical summary',
        'detailed_analysis': 'comprehensive technical analysis',
        'methodology': 'focus on methods and approach',
        'results': 'focus on results and findings',
        'conclusion': 'focus on conclusions and implications',
        'abstract': 'paper abstract-like summary',
    }

    # Intent returned when nothing matches (or the input is empty).
    DEFAULT_INTENT = 'technical_overview'

    def __init__(self):
        """Initialize intent classifier and build one prompt per intent."""
        self.intent_prompts = {
            intent_key: f"Provide {intent_desc}"
            for intent_key, intent_desc in self.INTENT_TYPES.items()
        }

    @staticmethod
    def _normalize(text: str) -> str:
        """Lower-case *text*, treat '_' and '-' as spaces, collapse whitespace."""
        spaced = text.lower().replace('_', ' ').replace('-', ' ')
        return ' '.join(spaced.split())

    def classify_intent(self, user_input: str) -> str:
        """
        Classify user intent from input text.

        Both the input and the intent keys are normalized the same way, so
        ``'detailed_analysis'``, ``'Detailed-Analysis'`` and
        ``'detailed   analysis'`` all resolve to ``'detailed_analysis'``.

        Args:
            user_input: User's intent description

        Returns:
            Classified intent type; ``'technical_overview'`` when the input
            is empty or no intent name occurs in it
        """
        if not user_input:
            return self.DEFAULT_INTENT

        normalized_input = self._normalize(user_input)

        # Simple keyword matching (can be enhanced with ML model)
        for intent_key in self.INTENT_TYPES:
            if self._normalize(intent_key) in normalized_input:
                return intent_key

        return self.DEFAULT_INTENT

    def get_prompt_for_intent(self, intent: str) -> str:
        """Get customized prompt for specific intent.

        Unknown intents receive the prompt of the default intent.
        """
        return self.intent_prompts.get(intent, self.intent_prompts[self.DEFAULT_INTENT])


class ContextPreserver:
    """Preserve important context during summarization.

    "Important" means a sentence matches at least one of the regular
    expressions in ``self.important_patterns`` (case-insensitively).
    """

    def __init__(self):
        """Initialize context preserver with its category -> regex table."""
        self.important_patterns = {
            'method': r'(?:method|approach|technique|algorithm)',
            'metric': r'(?:metric|accuracy|precision|recall|f1|score)',
            'dataset': r'(?:dataset|corpus|benchmark)',
            'baseline': r'(?:baseline|state-of-the-art|sota)',
        }

    def _count_matching_patterns(self, sentence: str) -> int:
        """Return how many pattern categories occur in *sentence*."""
        return sum(
            1
            for pattern in self.important_patterns.values()
            if re.search(pattern, sentence, re.IGNORECASE)
        )

    def extract_important_content(self, text: str) -> List[str]:
        """
        Extract important content that should be preserved.

        The text is split on '.' and a piece is kept when it is longer than
        20 characters and matches at least one important pattern.

        Args:
            text: Document text

        Returns:
            List of important content snippets (whitespace-stripped, in
            document order); empty for empty input
        """
        if not text:
            return []

        important_snippets = []
        # NOTE: splitting on '.' is deliberately simple; it also splits
        # inside decimals and abbreviations.
        for sentence in text.split('.'):
            if len(sentence) <= 20:  # Skip very short sentences
                continue
            if self._count_matching_patterns(sentence) > 0:
                important_snippets.append(sentence.strip())

        return important_snippets

    def weight_content(self, sentences: List[str]) -> List[float]:
        """
        Assign importance weights to sentences.

        Every sentence starts at 1.0 and gains 0.5 for each pattern
        category it matches; the result is capped at 2.0.

        Args:
            sentences: List of sentences

        Returns:
            List of importance weights, one per input sentence
        """
        weights = []
        for sentence in sentences:
            boost = 0.5 * self._count_matching_patterns(sentence)
            weights.append(min(1.0 + boost, 2.0))  # Cap at 2.0

        return weights