|
|
|
|
|
from typing import Dict |
|
|
from typing import List |
|
|
from typing import Tuple |
|
|
from loguru import logger |
|
|
from typing import Optional |
|
|
from dataclasses import dataclass |
|
|
from config.threshold_config import Domain |
|
|
from models.model_manager import get_model_manager |
|
|
from config.threshold_config import interpolate_thresholds |
|
|
from config.threshold_config import get_threshold_for_domain |
|
|
|
|
|
|
|
|
@dataclass
class DomainPrediction:
    """
    Outcome of one domain classification run
    """
    # Top-ranked domain (Domain.GENERAL when the default prediction is used)
    primary_domain: Domain

    # Runner-up domain, or None when no second domain qualified
    secondary_domain: Optional[Domain]

    # Confidence attached to the prediction by the classifier
    confidence: float

    # Score per domain, keyed by the domain's string value
    domain_scores: Dict[str, float]
|
|
|
|
|
|
|
|
class DomainClassifier:
    """
    Classifies text into domains using a primary model with a different fallback model

    Every domain is described by several candidate labels (see DOMAIN_LABELS). The per-label scores
    returned by the classifier are averaged per domain and the best-scoring domain becomes the prediction
    """
    # Candidate labels handed to the zero-shot classifier, grouped by the domain they map back to
    DOMAIN_LABELS = {Domain.ACADEMIC      : ["academic writing", "research paper", "scholarly article", "thesis", "scientific report"],
                     Domain.CREATIVE      : ["creative writing", "fiction", "poetry", "story", "narrative"],
                     Domain.AI_ML         : ["machine learning", "artificial intelligence", "neural networks", "data science", "AI research"],
                     Domain.SOFTWARE_DEV  : ["software development", "programming", "coding", "software engineering", "web development"],
                     Domain.TECHNICAL_DOC : ["technical documentation", "user manual", "API documentation", "technical guide", "installation guide"],
                     Domain.ENGINEERING   : ["engineering", "mechanical engineering", "electrical engineering", "design", "technical design"],
                     Domain.SCIENCE       : ["scientific research", "physics", "chemistry", "biology", "scientific study"],
                     Domain.BUSINESS      : ["business document", "corporate communication", "professional writing", "business report", "marketing"],
                     Domain.JOURNALISM    : ["news article", "journalism", "press release", "news report", "media"],
                     Domain.SOCIAL_MEDIA  : ["social media post", "blog post", "casual writing", "online content", "informal text"],
                     Domain.BLOG_PERSONAL : ["personal blog", "personal writing", "lifestyle blog", "personal experience", "opinion piece"],
                     Domain.LEGAL         : ["legal document", "contract", "legal writing", "law", "judicial"],
                     Domain.MEDICAL       : ["medical document", "healthcare", "clinical", "medical report", "health"],
                     Domain.MARKETING     : ["marketing content", "advertising", "brand content", "promotional writing", "sales copy"],
                     Domain.TUTORIAL      : ["tutorial", "how-to guide", "instructional content", "step-by-step guide", "educational guide"],
                     Domain.GENERAL       : ["general content", "everyday writing", "common text", "standard writing", "normal text"],
                    }

    # Inputs with more whitespace-separated words than this are truncated before classification
    MAX_INPUT_WORDS = 500


    def __init__(self):
        """
        Fetch the shared model manager; the classifiers themselves are loaded later by initialize()
        """
        self.model_manager       = get_model_manager()
        self.primary_classifier  = None
        self.fallback_classifier = None
        self.is_initialized      = False


    def initialize(self) -> bool:
        """
        Initialize the domain classifier with primary and fallback models

        A failure to load the fallback model is logged and tolerated; a failure to load the primary
        model aborts initialization

        Returns:
        --------
            { bool } : True when the primary classifier was loaded, False otherwise
        """
        try:
            logger.info("Initializing domain classifier...")

            self.primary_classifier = self.model_manager.load_pipeline(model_name = "domain_classifier",
                                                                       task       = "zero-shot-classification",
                                                                      )

            # The fallback is optional: keep going with the primary model alone if it cannot be loaded
            try:
                self.fallback_classifier = self.model_manager.load_pipeline(model_name = "domain_classifier_fallback",
                                                                            task       = "zero-shot-classification",
                                                                           )
                logger.info("Fallback classifier loaded successfully")

            except Exception as e:
                logger.warning(f"Could not load fallback classifier: {repr(e)}")
                self.fallback_classifier = None

            self.is_initialized = True
            logger.success("Domain classifier initialized successfully")
            return True

        except Exception as e:
            logger.error(f"Failed to initialize domain classifier: {repr(e)}")
            return False


    def classify(self, text: str, top_k: int = 2, min_confidence: float = 0.3) -> DomainPrediction:
        """
        Classify text into domain using primary model with fallback to different model

        The fallback model is consulted when the primary model fails or when its confidence is below
        min_confidence. A failure of the fallback model never discards a usable primary result

        Arguments:
        ----------
            text           { str }   : Input text

            top_k          { int }   : Number of top domains to consider (currently not used by the
                                       implementation; kept for interface compatibility)

            min_confidence { float } : Minimum confidence threshold for accepting the primary result
                                       without consulting the fallback model

        Returns:
        --------
            { DomainPrediction } : Result of the primary or the fallback model, or the default
                                   prediction when initialization or both models fail
        """
        if not self.is_initialized:
            logger.warning("Domain classifier not initialized, initializing now...")

            if not self.initialize():
                return self._get_default_prediction()

        # Primary model: a failure here is recorded as "no primary result" rather than aborting
        primary_result = None

        try:
            primary_result = self._classify_with_model(text       = text,
                                                       classifier = self.primary_classifier,
                                                       model_type = "primary",
                                                      )

        except Exception as e:
            logger.error(f"Error in primary domain classification: {repr(e)}")

        primary_failed = primary_result is None

        if ((not primary_failed) and (primary_result.confidence >= min_confidence)):
            return primary_result

        # Fallback model: used after a primary failure or a low-confidence primary result
        if self.fallback_classifier:
            try:
                if primary_failed:
                    logger.info("Trying fallback classifier after primary failure...")

                else:
                    logger.info("Primary classifier low confidence, trying fallback model...")

                fallback_result = self._classify_with_model(text       = text,
                                                            classifier = self.fallback_classifier,
                                                            model_type = "fallback",
                                                           )

            except Exception as fallback_error:
                if primary_failed:
                    logger.error(f"Fallback classifier also failed: {repr(fallback_error)}")

                else:
                    # The primary result is still valid, so it is returned below instead of the default
                    logger.error(f"Fallback classifier failed, keeping primary result: {repr(fallback_error)}")

            else:
                # The fallback only replaces an existing primary result when it is strictly more confident
                if (primary_failed or (fallback_result.confidence > primary_result.confidence)):
                    return fallback_result

        if not primary_failed:
            return primary_result

        return self._get_default_prediction()


    def _classify_with_model(self, text: str, classifier, model_type: str) -> DomainPrediction:
        """
        Classify using a specific model with interpolation for mixed domains

        Arguments:
        ----------
            text       { str }    : Input text

            classifier { object } : Classifier callable; invoked as classifier(text, candidate_labels, multi_label)
                                    and expected to return a mapping with 'labels' and 'scores'

            model_type { str }    : Type of model, used for logging only

        Returns:
        --------
            { DomainPrediction } : DomainPrediction object
        """
        # Truncate long inputs to the first MAX_INPUT_WORDS words; shorter inputs are passed unchanged
        words = text.split()

        if (len(words) > self.MAX_INPUT_WORDS):
            text = ' '.join(words[:self.MAX_INPUT_WORDS])

        # Flatten the label table and remember which domain every label belongs to
        all_labels      = [label for labels in self.DOMAIN_LABELS.values() for label in labels]
        label_to_domain = {label: domain for domain, labels in self.DOMAIN_LABELS.items() for label in labels}

        result = classifier(text,
                            candidate_labels = all_labels,
                            multi_label      = False,
                           )

        # Group the per-label scores by domain value
        domain_scores = dict()

        for label, score in zip(result['labels'], result['scores']):
            domain_key = label_to_domain[label].value
            domain_scores.setdefault(domain_key, list()).append(score)

        # NOTE(review): if the classifier is a Hugging Face zero-shot pipeline, multi_label = False
        # normalizes scores to sum to 1 across ALL candidate labels, so a mean over the five labels
        # of a domain cannot exceed 0.2 - confirm that the 0.2 / 0.3 / 0.7 thresholds below (and the
        # min_confidence default of classify) are actually reachable with this aggregation
        avg_domain_scores = {domain: sum(scores) / len(scores) for domain, scores in domain_scores.items()}

        # Rank domains from best to worst average score
        sorted_domains = sorted(avg_domain_scores.items(), key = lambda x: x[1], reverse = True)

        primary_domain_str, primary_score = sorted_domains[0]
        primary_domain                    = Domain(primary_domain_str)

        # A secondary domain is only reported when the runner-up reaches a score of 0.2
        secondary_domain = None
        secondary_score  = 0.0

        if ((len(sorted_domains) > 1) and (sorted_domains[1][1] >= 0.2)):
            secondary_domain = Domain(sorted_domains[1][0])
            secondary_score  = sorted_domains[1][1]

        # Mixed domain: primary not dominant, secondary substantial, and the two scores close together.
        # secondary_score > 0.3 with secondary <= primary guarantees a non-zero divisor here
        should_interpolate = False

        if (secondary_domain and (primary_score < 0.7) and (secondary_score > 0.3)):
            should_interpolate = (secondary_score / primary_score) > 0.6

        confidence = primary_score

        if (should_interpolate):
            # Mixed domains: mean of both scores, reduced by 20%
            confidence = (primary_score + secondary_score) / 2 * 0.8
            logger.info(f"Mixed domain detected: {primary_domain.value} + {secondary_domain.value}, will use interpolated thresholds")

        elif ((primary_score < 0.5) and secondary_domain):
            # Weak primary with a competing secondary domain: reduce confidence by 20%
            confidence *= 0.8

        logger.info(f"{model_type.capitalize()} model classified domain: {primary_domain.value} (confidence: {confidence:.2f})")

        return DomainPrediction(primary_domain   = primary_domain,
                                secondary_domain = secondary_domain,
                                confidence       = confidence,
                                domain_scores    = avg_domain_scores,
                               )


    def _get_default_prediction(self) -> DomainPrediction:
        """
        Get default prediction when classification fails

        Returns:
        --------
            { DomainPrediction } : Domain.GENERAL with confidence 0.5, no secondary domain and a
                                   uniform score for every member of Domain
        """
        return DomainPrediction(primary_domain   = Domain.GENERAL,
                                secondary_domain = None,
                                confidence       = 0.5,
                                domain_scores    = {domain.value: 1.0/len(Domain) for domain in Domain},
                               )


    def get_adaptive_thresholds(self, domain_prediction: DomainPrediction):
        """
        Get adaptive thresholds based on domain prediction with intelligent interpolation

        Arguments:
        ----------
            domain_prediction { DomainPrediction } : Domain prediction result

        Returns:
        --------
            Thresholds as returned by get_threshold_for_domain / interpolate_thresholds
        """
        # Confident single-domain prediction: use that domain's thresholds directly
        if ((domain_prediction.confidence > 0.7) and (not domain_prediction.secondary_domain)):
            return get_threshold_for_domain(domain_prediction.primary_domain)

        # Mixed prediction: blend both domains, weighting the primary by its share of the two scores
        if domain_prediction.secondary_domain:
            primary_score   = domain_prediction.domain_scores.get(domain_prediction.primary_domain.value, 0)
            secondary_score = domain_prediction.domain_scores.get(domain_prediction.secondary_domain.value, 0)

            if (primary_score + secondary_score > 0):
                weight1 = primary_score / (primary_score + secondary_score)

            else:
                # No usable scores recorded for either domain: fall back to the overall confidence
                weight1 = domain_prediction.confidence

            thresholds = interpolate_thresholds(domain1 = domain_prediction.primary_domain,
                                                domain2 = domain_prediction.secondary_domain,
                                                weight1 = weight1,
                                               )
            return thresholds

        # Low-confidence single-domain prediction: blend with the GENERAL thresholds
        if (domain_prediction.confidence < 0.6):
            thresholds = interpolate_thresholds(domain1 = domain_prediction.primary_domain,
                                                domain2 = Domain.GENERAL,
                                                weight1 = domain_prediction.confidence,
                                               )
            return thresholds

        # Confidence between 0.6 and 0.7 without a secondary domain
        return get_threshold_for_domain(domain_prediction.primary_domain)


    def cleanup(self):
        """
        Clean up resources

        Drops the references to both classifiers and marks the instance as uninitialized, so the
        next classify() call initializes again
        """
        self.primary_classifier  = None
        self.fallback_classifier = None
        self.is_initialized      = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Public names exported by this module
__all__ = ["DomainClassifier", "DomainPrediction"]