File size: 12,407 Bytes
23654e5
 
 
e6341fe
67d3f72
 
 
23654e5
 
e6341fe
 
23654e5
e6341fe
23654e5
 
 
 
e6341fe
 
 
 
 
 
 
23654e5
e6341fe
 
 
 
 
 
23654e5
 
 
 
 
 
 
 
 
e6341fe
 
 
 
23654e5
 
 
 
 
 
 
 
 
e6341fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23654e5
 
e6341fe
 
 
 
 
 
 
23654e5
e6341fe
 
 
 
 
 
1377fb1
 
23654e5
e6341fe
 
 
 
23654e5
e6341fe
 
 
 
 
1377fb1
 
 
 
 
 
 
 
 
e6341fe
 
1377fb1
e6341fe
 
 
1377fb1
 
e6341fe
 
 
23654e5
 
 
 
 
 
 
 
 
 
 
 
 
 
e6341fe
 
 
 
 
 
23654e5
 
 
 
 
 
e6341fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71797a4
 
 
e6341fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71797a4
 
 
e6341fe
 
 
 
 
23654e5
 
 
 
 
 
 
 
 
 
 
71797a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23654e5
e6341fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00aacad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6341fe
 
 
 
 
 
 
 
 
23654e5
 
 
 
 
 
 
 
 
e6341fe
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
"""
AI-powered submission analyzer using Hugging Face zero-shot classification.
This module provides free, offline classification without requiring API keys.
Supports both base models and fine-tuned models with LoRA.

Copyright (c) 2024-2025 Marcos Thadeu Queiroz Magalhães (thadillo@gmail.com)
Licensed under MIT License - See LICENSE file for details
"""

from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
import torch
import logging
import os

logger = logging.getLogger(__name__)

class SubmissionAnalyzer:
    """
    Classify submission messages into strategy categories.

    Prefers a deployed fine-tuned model (looked up through the FineTuningRun
    table) and falls back to a zero-shot classification pipeline selected via
    the 'zero_shot_model' setting. Models are loaded lazily on first use.
    """

    def __init__(self, use_finetuned: bool = True):
        """
        Initialize the classification model.

        Args:
            use_finetuned: Whether to check for and use fine-tuned models (default: True)
        """
        self.classifier = None        # zero-shot pipeline (base-model path)
        self.model = None             # fine-tuned sequence-classification model
        self.tokenizer = None         # tokenizer paired with the fine-tuned model
        self.use_finetuned = use_finetuned
        self.model_type = 'base'  # 'base' or 'finetuned'
        self.active_run_id = None     # FineTuningRun.id of the deployed model, if any
        self.zero_shot_model_id = None    # HF model id actually loaded for zero-shot
        self._last_confidence = None      # confidence of the most recent prediction

        self.categories = [
            'Vision',
            'Problem',
            'Objectives',
            'Directives',
            'Values',
            'Actions'
        ]

        self.label2id = {label: idx for idx, label in enumerate(self.categories)}
        self.id2label = {idx: label for idx, label in enumerate(self.categories)}

        # Category descriptions for better zero-shot classification
        self.category_descriptions = {
            'Vision': 'future aspirations, desired outcomes, what success looks like',
            'Problem': 'current issues, frustrations, causes of problems',
            'Objectives': 'specific goals to achieve',
            'Directives': 'restrictions or requirements for solution design',
            'Values': 'principles or restrictions for setting objectives',
            'Actions': 'concrete steps, interventions, or activities to implement'
        }

    def _check_for_finetuned_model(self):
        """
        Check if a fine-tuned model is active in the database.

        Returns:
            str | None: Filesystem path of the active fine-tuned model, or
            None when fine-tuning is disabled, no run is active, or the
            model directory is missing.
        """
        if not self.use_finetuned:
            return None

        try:
            # Imported lazily so the analyzer can be used outside the app context.
            from app.models.models import FineTuningRun
            from app import db

            active_run = db.session.query(FineTuningRun).filter_by(is_active_model=True).first()

            if active_run:
                models_dir = os.getenv('MODELS_DIR', '/data/models/finetuned')
                model_path = os.path.join(models_dir, f'run_{active_run.id}')

                if os.path.exists(model_path):
                    logger.info(f"Found active fine-tuned model: run_{active_run.id}")
                    # Remember which run we are serving so get_model_info() can report it.
                    self.active_run_id = active_run.id
                    return model_path
                else:
                    logger.warning(f"Active model path not found: {model_path}")

        except Exception as e:
            # Best-effort check: a missing table or app context just means "no fine-tuned model".
            logger.warning(f"Could not check for fine-tuned model: {e}")

        return None

    def _load_model(self):
        """Lazy load the model only when needed."""
        if self.classifier is not None or self.model is not None:
            return  # Already loaded

        # Check for fine-tuned model first
        finetuned_path = self._check_for_finetuned_model()

        if finetuned_path:
            try:
                logger.info(f"Loading fine-tuned model from {finetuned_path}")
                self.tokenizer = AutoTokenizer.from_pretrained(finetuned_path)
                self.model = AutoModelForSequenceClassification.from_pretrained(
                    finetuned_path,
                    num_labels=len(self.categories),
                    id2label=self.id2label,
                    label2id=self.label2id,
                    ignore_mismatched_sizes=True
                )
                self.model.eval()
                self.model_type = 'finetuned'
                logger.info("Fine-tuned model loaded successfully!")
                return
            except Exception as e:
                logger.error(f"Error loading fine-tuned model: {e}")
                logger.info("Falling back to base model")
                # The fine-tuned run is not actually serving; don't report it as active.
                self.active_run_id = None

        # Load base zero-shot model
        try:
            # Get selected zero-shot model from settings
            from app.models.models import Settings
            from app.fine_tuning.model_presets import get_model_preset

            zero_shot_model_key = Settings.get_setting('zero_shot_model', 'bart-large-mnli')
            model_preset = get_model_preset(zero_shot_model_key)
            zero_shot_model_id = model_preset['model_id']

            logger.info(f"Loading zero-shot classification model: {zero_shot_model_id}...")
            self.classifier = pipeline(
                "zero-shot-classification",
                model=zero_shot_model_id,
                device=-1  # Use CPU (-1), change to 0 for GPU
            )
            self.model_type = 'base'
            self.zero_shot_model_key = zero_shot_model_key
            # Remember the actual HF model id so get_model_info() reports the truth.
            self.zero_shot_model_id = zero_shot_model_id
            logger.info(f"Zero-shot model loaded successfully: {model_preset['name']}!")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

    def analyze(self, message):
        """
        Classify a submission message into one of the predefined categories.

        Args:
            message (str): The submission message to classify

        Returns:
            str: The predicted category ('Problem' as fallback on any error)
        """
        self._load_model()

        try:
            if self.model_type == 'finetuned':
                # Use fine-tuned model
                return self._classify_with_finetuned(message)
            else:
                # Use base zero-shot model
                return self._classify_with_zeroshot(message)

        except Exception as e:
            logger.error(f"Error analyzing message: {e}")
            # Fallback to Problem category if analysis fails
            return 'Problem'

    def _classify_with_finetuned(self, message):
        """Classify using fine-tuned model"""
        # Tokenize
        inputs = self.tokenizer(
            message,
            truncation=True,
            padding='max_length',
            max_length=128,
            return_tensors='pt'
        )

        # Predict
        with torch.no_grad():
            outputs = self.model(**inputs)
            predictions = torch.softmax(outputs.logits, dim=1)
            predicted_class = torch.argmax(predictions, dim=1).item()
            confidence = predictions[0][predicted_class].item()

        category = self.id2label[predicted_class]

        # Store confidence for later retrieval
        self._last_confidence = confidence

        logger.info(f"Fine-tuned model classified as: {category} (confidence: {confidence:.2f})")

        return category

    def _classify_with_zeroshot(self, message):
        """Classify using zero-shot base model"""
        # Use category descriptions as labels for better accuracy
        candidate_labels = [
            f"{cat}: {self.category_descriptions[cat]}"
            for cat in self.categories
        ]

        # Run classification
        result = self.classifier(
            message,
            candidate_labels,
            multi_label=False
        )

        # Extract the category name from the label
        top_label = result['labels'][0]
        category = top_label.split(':')[0]

        # Store confidence for later retrieval
        self._last_confidence = result['scores'][0]

        logger.info(f"Zero-shot model classified as: {category} (confidence: {result['scores'][0]:.2f})")

        return category

    def analyze_batch(self, messages):
        """
        Classify multiple messages at once.

        Args:
            messages (list): List of submission messages

        Returns:
            list: List of predicted categories
        """
        return [self.analyze(msg) for msg in messages]

    def analyze_sentences(self, sentences: list) -> list:
        """
        Analyze multiple sentences and return their categories with confidence scores.

        Args:
            sentences: List of sentence strings

        Returns:
            List of dicts with keys: 'text', 'category', 'confidence'
            (confidence is None when the underlying model did not report one)
        """
        self._load_model()

        results = []
        for sentence in sentences:
            try:
                category = self.analyze(sentence)
                # Both classifier paths record their score in _last_confidence.
                results.append({
                    'text': sentence,
                    'category': category,
                    'confidence': self._get_last_confidence()
                })
            except Exception as e:
                logger.error(f"Error analyzing sentence '{sentence[:50]}...': {e}")
                results.append({
                    'text': sentence,
                    'category': 'Problem',  # Fallback
                    'confidence': None
                })

        return results

    def analyze_with_sentences(self, text: str) -> list:
        """
        Segment text into sentences and analyze each one.

        Args:
            text: Full text to segment and analyze

        Returns:
            List of dicts with keys: 'text', 'category', 'confidence'
        """
        from app.sentence_segmenter import SentenceSegmenter

        # Segment text into sentences
        segmenter = SentenceSegmenter()
        sentences = segmenter.segment(text)

        # Analyze each sentence
        return self.analyze_sentences(sentences)

    def _get_last_confidence(self):
        """Get last prediction confidence (if available)"""
        return getattr(self, '_last_confidence', None)

    def get_model_info(self):
        """
        Get information about the currently loaded model.

        Returns:
            Dict with model information
        """
        self._load_model()

        info = {
            'model_type': self.model_type,
            'categories': self.categories
        }

        if self.model_type == 'finetuned':
            info['active_run_id'] = self.active_run_id
            info['model_loaded'] = self.model is not None
        else:
            # Report the model that was actually loaded; fall back to the
            # historical default when it is not yet known.
            info['base_model'] = self.zero_shot_model_id or 'facebook/bart-large-mnli'
            info['model_loaded'] = self.classifier is not None

        return info

    def reload_model(self):
        """Force reload the model (useful after deploying a new fine-tuned model)"""
        self.classifier = None
        self.model = None
        self.tokenizer = None
        self.model_type = 'base'
        self.active_run_id = None
        self.zero_shot_model_id = None
        logger.info("Model cache cleared, will reload on next analysis")

# Module-level singleton; created lazily by get_analyzer() and cleared via reload_analyzer().
_analyzer = None

def get_analyzer():
    """Return the shared SubmissionAnalyzer, constructing it on first call."""
    global _analyzer
    if _analyzer is not None:
        return _analyzer
    _analyzer = SubmissionAnalyzer()
    return _analyzer

def reload_analyzer():
    """Clear the shared analyzer's cached model so the next analysis reloads it."""
    global _analyzer
    analyzer = _analyzer
    if analyzer is not None:
        analyzer.reload_model()
    logger.info("Analyzer reloaded")