# thadillo — commit 67d3f72: Add copyright and attribution for
# Marcos Thadeu Queiroz Magalhães
"""
AI-powered submission analyzer using Hugging Face zero-shot classification.
This module provides free, offline classification without requiring API keys.
Supports both base models and fine-tuned models with LoRA.
Copyright (c) 2024-2025 Marcos Thadeu Queiroz Magalhães (thadillo@gmail.com)
Licensed under MIT License - See LICENSE file for details
"""
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
import torch
import logging
import os
logger = logging.getLogger(__name__)
class SubmissionAnalyzer:
    """Classify submission text into strategy-map categories.

    Prefers an active fine-tuned model (deployed through the fine-tuning
    workflow and flagged in the database) and falls back to a zero-shot
    base model selected in application settings. Models are loaded lazily
    on first use so importing this module stays cheap.
    """

    def __init__(self, use_finetuned: bool = True):
        """
        Initialize classifier state; no model is loaded yet.

        Args:
            use_finetuned: Whether to check for and use fine-tuned models (default: True)
        """
        self.classifier = None        # zero-shot pipeline (base mode)
        self.model = None             # sequence-classification model (finetuned mode)
        self.tokenizer = None         # tokenizer paired with the fine-tuned model
        self.use_finetuned = use_finetuned
        self.model_type = 'base'      # 'base' or 'finetuned'
        self.active_run_id = None     # FineTuningRun id backing self.model, if any
        self._last_confidence = None  # confidence of the most recent prediction
        self.categories = [
            'Vision',
            'Problem',
            'Objectives',
            'Directives',
            'Values',
            'Actions'
        ]
        self.label2id = {label: idx for idx, label in enumerate(self.categories)}
        self.id2label = {idx: label for idx, label in enumerate(self.categories)}
        # Category descriptions for better zero-shot classification
        self.category_descriptions = {
            'Vision': 'future aspirations, desired outcomes, what success looks like',
            'Problem': 'current issues, frustrations, causes of problems',
            'Objectives': 'specific goals to achieve',
            'Directives': 'restrictions or requirements for solution design',
            'Values': 'principles or restrictions for setting objectives',
            'Actions': 'concrete steps, interventions, or activities to implement'
        }

    def _check_for_finetuned_model(self):
        """Return (model_path, run_id) for the active fine-tuned model, or None.

        Looks up the FineTuningRun flagged active in the database and verifies
        its exported weights still exist on disk. Any failure (no app context,
        no active row, missing path) degrades gracefully to None so the caller
        can fall back to the base model.
        """
        if not self.use_finetuned:
            return None
        try:
            from app.models.models import FineTuningRun
            from app import db
            active_run = db.session.query(FineTuningRun).filter_by(is_active_model=True).first()
            if active_run:
                models_dir = os.getenv('MODELS_DIR', '/data/models/finetuned')
                model_path = os.path.join(models_dir, f'run_{active_run.id}')
                if os.path.exists(model_path):
                    logger.info(f"Found active fine-tuned model: run_{active_run.id}")
                    return model_path, active_run.id
                else:
                    logger.warning(f"Active model path not found: {model_path}")
        except Exception as e:
            logger.warning(f"Could not check for fine-tuned model: {e}")
        return None

    def _load_model(self):
        """Lazy load the model only when needed."""
        if self.classifier is not None or self.model is not None:
            return  # Already loaded
        # Check for fine-tuned model first
        finetuned = self._check_for_finetuned_model()
        if finetuned:
            finetuned_path, run_id = finetuned
            try:
                logger.info(f"Loading fine-tuned model from {finetuned_path}")
                self.tokenizer = AutoTokenizer.from_pretrained(finetuned_path)
                self.model = AutoModelForSequenceClassification.from_pretrained(
                    finetuned_path,
                    num_labels=len(self.categories),
                    id2label=self.id2label,
                    label2id=self.label2id,
                    ignore_mismatched_sizes=True
                )
                self.model.eval()
                self.model_type = 'finetuned'
                # Record which run backs the loaded weights so get_model_info()
                # can report it (previously this was never set).
                self.active_run_id = run_id
                logger.info("Fine-tuned model loaded successfully!")
                return
            except Exception as e:
                logger.error(f"Error loading fine-tuned model: {e}")
                logger.info("Falling back to base model")
        # Load base zero-shot model
        try:
            # Get selected zero-shot model from settings
            from app.models.models import Settings
            from app.fine_tuning.model_presets import get_model_preset
            zero_shot_model_key = Settings.get_setting('zero_shot_model', 'bart-large-mnli')
            model_preset = get_model_preset(zero_shot_model_key)
            zero_shot_model_id = model_preset['model_id']
            logger.info(f"Loading zero-shot classification model: {zero_shot_model_id}...")
            self.classifier = pipeline(
                "zero-shot-classification",
                model=zero_shot_model_id,
                device=-1  # Use CPU (-1), change to 0 for GPU
            )
            self.model_type = 'base'
            self.zero_shot_model_key = zero_shot_model_key
            # Remember the resolved model id so get_model_info() reports the
            # model actually loaded rather than a hard-coded default.
            self.zero_shot_model_id = zero_shot_model_id
            logger.info(f"Zero-shot model loaded successfully: {model_preset['name']}!")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

    def analyze(self, message):
        """
        Classify a submission message into one of the predefined categories.

        Args:
            message (str): The submission message to classify
        Returns:
            str: The predicted category
        """
        self._load_model()
        try:
            if self.model_type == 'finetuned':
                # Use fine-tuned model
                return self._classify_with_finetuned(message)
            else:
                # Use base zero-shot model
                return self._classify_with_zeroshot(message)
        except Exception as e:
            logger.error(f"Error analyzing message: {e}")
            # Fallback to Problem category if analysis fails. Clear any stale
            # confidence from a previous prediction so callers don't attach
            # an unrelated score to this fallback result.
            self._last_confidence = None
            return 'Problem'

    def _classify_with_finetuned(self, message):
        """Classify using the fine-tuned model; stores prediction confidence."""
        # Tokenize
        inputs = self.tokenizer(
            message,
            truncation=True,
            padding='max_length',
            max_length=128,
            return_tensors='pt'
        )
        # Predict
        with torch.no_grad():
            outputs = self.model(**inputs)
            predictions = torch.softmax(outputs.logits, dim=1)
            predicted_class = torch.argmax(predictions, dim=1).item()
            confidence = predictions[0][predicted_class].item()
        category = self.id2label[predicted_class]
        # Store confidence for later retrieval
        self._last_confidence = confidence
        logger.info(f"Fine-tuned model classified as: {category} (confidence: {confidence:.2f})")
        return category

    def _classify_with_zeroshot(self, message):
        """Classify using the zero-shot base model; stores prediction confidence."""
        # Use category descriptions as labels for better accuracy
        candidate_labels = [
            f"{cat}: {self.category_descriptions[cat]}"
            for cat in self.categories
        ]
        # Run classification
        result = self.classifier(
            message,
            candidate_labels,
            multi_label=False
        )
        # Extract the category name from the "Category: description" label
        top_label = result['labels'][0]
        category = top_label.split(':')[0]
        # Store confidence for later retrieval
        self._last_confidence = result['scores'][0]
        logger.info(f"Zero-shot model classified as: {category} (confidence: {result['scores'][0]:.2f})")
        return category

    def analyze_batch(self, messages):
        """
        Classify multiple messages at once.

        Args:
            messages (list): List of submission messages
        Returns:
            list: List of predicted categories
        """
        return [self.analyze(msg) for msg in messages]

    def analyze_sentences(self, sentences: list) -> list:
        """
        Analyze multiple sentences and return their categories with confidence scores.

        Args:
            sentences: List of sentence strings
        Returns:
            List of dicts with keys: 'text', 'category', 'confidence'
            ('confidence' is a float when the classifier produced one,
            None when classification fell back or failed)
        """
        self._load_model()
        results = []
        for sentence in sentences:
            try:
                category = self.analyze(sentence)
                results.append({
                    'text': sentence,
                    'category': category,
                    # Both classifiers record _last_confidence on success;
                    # analyze() clears it on fallback, so this is never stale.
                    'confidence': self._get_last_confidence()
                })
                logger.info(f"Sentence classified: '{sentence[:50]}...' -> {category}")
            except Exception as e:
                logger.error(f"Error analyzing sentence '{sentence[:50]}...': {e}")
                results.append({
                    'text': sentence,
                    'category': 'Problem',  # Fallback
                    'confidence': None
                })
        return results

    def analyze_with_sentences(self, text: str) -> list:
        """
        Segment text into sentences and analyze each one.

        NOTE: an earlier duplicate definition of this method (based on
        TextProcessor) was shadowed by this one and has been removed; the
        per-sentence confidence it provided is now returned by
        analyze_sentences().

        Args:
            text: Full text to segment and analyze
        Returns:
            List of dicts with keys: 'text', 'category', 'confidence'
        """
        from app.sentence_segmenter import SentenceSegmenter
        # Segment text into sentences
        segmenter = SentenceSegmenter()
        sentences = segmenter.segment(text)
        # Analyze each sentence
        return self.analyze_sentences(sentences)

    def _get_last_confidence(self):
        """Get last prediction confidence (if available)."""
        return getattr(self, '_last_confidence', None)

    def get_model_info(self):
        """
        Get information about the currently loaded model.

        Returns:
            Dict with model information
        """
        self._load_model()
        info = {
            'model_type': self.model_type,
            'categories': self.categories
        }
        if self.model_type == 'finetuned':
            info['active_run_id'] = self.active_run_id
            info['model_loaded'] = self.model is not None
        else:
            # Report the model actually loaded from settings; the hard-coded
            # id is only a fallback for analyzers loaded before _load_model
            # started recording it.
            info['base_model'] = getattr(self, 'zero_shot_model_id', 'facebook/bart-large-mnli')
            info['model_loaded'] = self.classifier is not None
        return info

    def reload_model(self):
        """Force reload the model (useful after deploying a new fine-tuned model)."""
        self.classifier = None
        self.model = None
        self.tokenizer = None
        self.model_type = 'base'
        self.active_run_id = None
        self._last_confidence = None  # prediction state belongs to the old model
        logger.info("Model cache cleared, will reload on next analysis")
# Process-wide analyzer singleton; built lazily on first request.
_analyzer = None


def get_analyzer():
    """Return the shared SubmissionAnalyzer, creating it on first call."""
    global _analyzer
    _analyzer = _analyzer if _analyzer is not None else SubmissionAnalyzer()
    return _analyzer
def reload_analyzer():
    """Clear the cached model on the shared analyzer (useful after model deployment)."""
    global _analyzer
    if _analyzer is None:
        return
    _analyzer.reload_model()
    logger.info("Analyzer reloaded")