# NOTE(review): removed non-Python page-scrape residue ("Spaces: Sleeping Sleeping")
# that preceded the module docstring — it was Hugging Face Spaces UI text, not code.
"""
AI-powered submission analyzer using Hugging Face zero-shot classification.

This module provides free, offline classification without requiring API keys.
Supports both base models and fine-tuned models with LoRA.

Copyright (c) 2024-2025 Marcos Thadeu Queiroz Magalhães (thadillo@gmail.com)
Licensed under MIT License - See LICENSE file for details
"""
import logging
import os

import torch
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification

logger = logging.getLogger(__name__)
class SubmissionAnalyzer:
    """Classify submission text into the project's planning categories.

    Prefers an active fine-tuned sequence-classification model (looked up in
    the database and loaded from disk); otherwise falls back to a Hugging Face
    zero-shot classification pipeline selected via application settings.
    Models are loaded lazily on first use.
    """

    def __init__(self, use_finetuned: bool = True):
        """
        Initialize the classification model.

        Args:
            use_finetuned: Whether to check for and use fine-tuned models
                (default: True)
        """
        self.classifier = None            # zero-shot pipeline (base path)
        self.model = None                 # fine-tuned model (finetuned path)
        self.tokenizer = None             # tokenizer paired with self.model
        self.use_finetuned = use_finetuned
        self.model_type = 'base'          # 'base' or 'finetuned'
        self.active_run_id = None         # FineTuningRun id of the active model
        self.zero_shot_model_key = None   # settings key of the loaded base model
        self.zero_shot_model_id = None    # HF model id of the loaded base model
        self.categories = [
            'Vision',
            'Problem',
            'Objectives',
            'Directives',
            'Values',
            'Actions'
        ]
        self.label2id = {label: idx for idx, label in enumerate(self.categories)}
        self.id2label = {idx: label for idx, label in enumerate(self.categories)}
        # Category descriptions for better zero-shot classification
        self.category_descriptions = {
            'Vision': 'future aspirations, desired outcomes, what success looks like',
            'Problem': 'current issues, frustrations, causes of problems',
            'Objectives': 'specific goals to achieve',
            'Directives': 'restrictions or requirements for solution design',
            'Values': 'principles or restrictions for setting objectives',
            'Actions': 'concrete steps, interventions, or activities to implement'
        }

    def _check_for_finetuned_model(self):
        """Return the filesystem path of the active fine-tuned model, or None.

        Queries the database for the FineTuningRun flagged as active and
        verifies its model directory exists. Records the run id in
        ``self.active_run_id`` so get_model_info() can report it (bug fix:
        the attribute was previously never assigned).
        """
        if not self.use_finetuned:
            return None
        try:
            from app.models.models import FineTuningRun
            from app import db

            active_run = db.session.query(FineTuningRun).filter_by(is_active_model=True).first()
            if active_run:
                models_dir = os.getenv('MODELS_DIR', '/data/models/finetuned')
                model_path = os.path.join(models_dir, f'run_{active_run.id}')
                if os.path.exists(model_path):
                    logger.info(f"Found active fine-tuned model: run_{active_run.id}")
                    self.active_run_id = active_run.id
                    return model_path
                logger.warning(f"Active model path not found: {model_path}")
        except Exception as e:
            # Best-effort: the DB may be unavailable outside an app context;
            # in that case we simply fall back to the base model.
            logger.warning(f"Could not check for fine-tuned model: {e}")
        return None

    def _load_model(self):
        """Lazy load the model only when needed.

        Tries the active fine-tuned model first; on any failure (or when
        none is active) falls back to the zero-shot pipeline selected in
        the application settings. Raises only if the base model also fails.
        """
        if self.classifier is not None or self.model is not None:
            return  # Already loaded

        # Check for fine-tuned model first
        finetuned_path = self._check_for_finetuned_model()
        if finetuned_path:
            try:
                logger.info(f"Loading fine-tuned model from {finetuned_path}")
                self.tokenizer = AutoTokenizer.from_pretrained(finetuned_path)
                self.model = AutoModelForSequenceClassification.from_pretrained(
                    finetuned_path,
                    num_labels=len(self.categories),
                    id2label=self.id2label,
                    label2id=self.label2id,
                    ignore_mismatched_sizes=True
                )
                self.model.eval()  # inference mode (disables dropout etc.)
                self.model_type = 'finetuned'
                logger.info("Fine-tuned model loaded successfully!")
                return
            except Exception as e:
                logger.error(f"Error loading fine-tuned model: {e}")
                logger.info("Falling back to base model")

        # Load base zero-shot model
        try:
            # Get selected zero-shot model from settings
            from app.models.models import Settings
            from app.fine_tuning.model_presets import get_model_preset

            zero_shot_model_key = Settings.get_setting('zero_shot_model', 'bart-large-mnli')
            model_preset = get_model_preset(zero_shot_model_key)
            zero_shot_model_id = model_preset['model_id']
            logger.info(f"Loading zero-shot classification model: {zero_shot_model_id}...")
            self.classifier = pipeline(
                "zero-shot-classification",
                model=zero_shot_model_id,
                device=-1  # Use CPU (-1), change to 0 for GPU
            )
            self.model_type = 'base'
            self.zero_shot_model_key = zero_shot_model_key
            # Remember the concrete model id so get_model_info() can report it.
            self.zero_shot_model_id = zero_shot_model_id
            logger.info(f"Zero-shot model loaded successfully: {model_preset['name']}!")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

    def analyze(self, message):
        """
        Classify a submission message into one of the predefined categories.

        Args:
            message (str): The submission message to classify

        Returns:
            str: The predicted category ('Problem' on any analysis failure)
        """
        self._load_model()
        try:
            if self.model_type == 'finetuned':
                # Use fine-tuned model
                return self._classify_with_finetuned(message)
            else:
                # Use base zero-shot model
                return self._classify_with_zeroshot(message)
        except Exception as e:
            logger.error(f"Error analyzing message: {e}")
            # Fallback to Problem category if analysis fails
            return 'Problem'

    def _classify_with_finetuned(self, message):
        """Classify using the fine-tuned model.

        Stores the softmax confidence of the winning class in
        ``self._last_confidence`` for retrieval via _get_last_confidence().
        """
        # Tokenize (fixed-length padding to match the fine-tuning setup)
        inputs = self.tokenizer(
            message,
            truncation=True,
            padding='max_length',
            max_length=128,
            return_tensors='pt'
        )
        # Predict without building autograd graphs
        with torch.no_grad():
            outputs = self.model(**inputs)
            predictions = torch.softmax(outputs.logits, dim=1)
            predicted_class = torch.argmax(predictions, dim=1).item()
            confidence = predictions[0][predicted_class].item()
        category = self.id2label[predicted_class]
        # Store confidence for later retrieval
        self._last_confidence = confidence
        logger.info(f"Fine-tuned model classified as: {category} (confidence: {confidence:.2f})")
        return category

    def _classify_with_zeroshot(self, message):
        """Classify using the zero-shot base model.

        Stores the top label's score in ``self._last_confidence`` for
        retrieval via _get_last_confidence().
        """
        # Use "Category: description" labels for better zero-shot accuracy
        candidate_labels = [
            f"{cat}: {self.category_descriptions[cat]}"
            for cat in self.categories
        ]
        # Run classification (single-label: scores sum to 1)
        result = self.classifier(
            message,
            candidate_labels,
            multi_label=False
        )
        # Extract the category name back out of the "Category: description" label
        top_label = result['labels'][0]
        category = top_label.split(':')[0]
        # Store confidence for later retrieval
        self._last_confidence = result['scores'][0]
        logger.info(f"Zero-shot model classified as: {category} (confidence: {result['scores'][0]:.2f})")
        return category

    def analyze_batch(self, messages):
        """
        Classify multiple messages at once.

        Args:
            messages (list): List of submission messages

        Returns:
            list: List of predicted categories
        """
        return [self.analyze(msg) for msg in messages]

    def _get_last_confidence(self):
        """Return the confidence of the most recent prediction, or None."""
        return getattr(self, '_last_confidence', None)

    def get_model_info(self):
        """
        Get information about the currently loaded model.

        Returns:
            Dict with model information
        """
        self._load_model()
        info = {
            'model_type': self.model_type,
            'categories': self.categories
        }
        if self.model_type == 'finetuned':
            info['active_run_id'] = self.active_run_id
            info['model_loaded'] = self.model is not None
        else:
            # Report the model actually loaded from settings (bug fix: this
            # was previously hard-coded to 'facebook/bart-large-mnli' even
            # when a different zero-shot model had been selected).
            info['base_model'] = self.zero_shot_model_id or 'facebook/bart-large-mnli'
            info['model_loaded'] = self.classifier is not None
        return info

    def analyze_sentences(self, sentences: list) -> list:
        """
        Analyze multiple sentences and return their categories with confidence scores.

        Args:
            sentences: List of sentence strings

        Returns:
            List of dicts with keys: 'text', 'category', 'confidence'
        """
        self._load_model()
        results = []
        for sentence in sentences:
            try:
                # Reset so a failed prediction cannot report a stale score.
                self._last_confidence = None
                category = self.analyze(sentence)
                # Both classifier paths record their confidence (bug fix:
                # this previously always returned None).
                results.append({
                    'text': sentence,
                    'category': category,
                    'confidence': self._get_last_confidence()
                })
            except Exception as e:
                logger.error(f"Error analyzing sentence '{sentence[:50]}...': {e}")
                results.append({
                    'text': sentence,
                    'category': 'Problem',  # Fallback
                    'confidence': None
                })
        return results

    def analyze_with_sentences(self, text: str) -> list:
        """
        Segment text into sentences and analyze each one.

        Note: an earlier duplicate definition of this method (based on
        app.utils.text_processor.TextProcessor) was dead code silently
        shadowed by this one and has been removed.

        Args:
            text: Full text to segment and analyze

        Returns:
            List of dicts with keys: 'text', 'category', 'confidence'
        """
        from app.sentence_segmenter import SentenceSegmenter

        # Segment text into sentences
        segmenter = SentenceSegmenter()
        sentences = segmenter.segment(text)
        # Analyze each sentence
        return self.analyze_sentences(sentences)

    def reload_model(self):
        """Force reload the model (useful after deploying a new fine-tuned model)"""
        self.classifier = None
        self.model = None
        self.tokenizer = None
        self.model_type = 'base'
        self.active_run_id = None
        self._last_confidence = None  # drop stale confidence from the old model
        logger.info("Model cache cleared, will reload on next analysis")
# Module-level singleton: one analyzer shared by the whole process.
_analyzer = None


def get_analyzer():
    """Return the shared SubmissionAnalyzer, creating it on first use."""
    global _analyzer
    if _analyzer is not None:
        return _analyzer
    _analyzer = SubmissionAnalyzer()
    return _analyzer
def reload_analyzer():
    """Force reload the analyzer (useful after model deployment)"""
    global _analyzer
    if _analyzer is None:
        # Nothing cached yet — the next get_analyzer() builds a fresh one.
        return
    _analyzer.reload_model()
    logger.info("Analyzer reloaded")