Spaces:
Sleeping
Sleeping
| from typing import Dict, List | |
| import re | |
| from models.hate_speech_classifier import HateSpeechClassifier | |
| from models.language_detector import detect_language | |
# Initialize classifier globally: one shared HateSpeechClassifier instance,
# constructed once at import time and reused by every analyze_content() call
# (avoids rebuilding the models per request).
classifier = HateSpeechClassifier()
def highlight_keywords(text: str, keywords: List[str], max_phrases: int = 5) -> List[str]:
    """Return the sentences of *text* that contain any of *keywords*.

    The text is split into sentences on runs of ``.``, ``!``, ``?`` and the
    Devanagari danda ``।``. For each keyword (case-insensitive substring
    match) the first sentence containing it is collected. A sentence matched
    by several keywords is reported only once. At most *max_phrases*
    sentences are returned, in keyword order.

    Args:
        text: The text to search.
        keywords: Keywords to look for; empty strings are ignored because
            they would trivially match every sentence.
        max_phrases: Maximum number of sentences to return (default 5).

    Returns:
        A list of whitespace-stripped sentences, possibly empty.
    """
    if max_phrases <= 0 or not text:
        return []

    # Split once instead of once per keyword, and keep a lowercase copy of
    # each sentence so every keyword test is a plain substring check.
    sentences = [
        (sentence.strip(), sentence.lower())
        for sentence in re.split(r'[।.!?]+', text)
    ]

    highlighted: List[str] = []
    seen: set = set()
    for keyword in keywords:
        needle = keyword.lower()
        if not needle:
            continue
        for stripped, lowered in sentences:
            if needle in lowered:
                # Only the first matching sentence per keyword is used. Skip it
                # if another keyword already produced the same sentence.
                if stripped not in seen:
                    seen.add(stripped)
                    highlighted.append(stripped)
                break
        if len(highlighted) >= max_phrases:
            break
    return highlighted
def _ensemble_weights(has_patterns: bool, has_hate_keywords: bool) -> Dict[str, float]:
    """Pick the per-source voting weights for the ensemble.

    When keyword analysis found hate patterns or hate keywords, the custom and
    keyword models are trusted more than the pretrained model; otherwise the
    custom and pretrained models share the weight equally.
    """
    if has_patterns or has_hate_keywords:
        return {"custom_model": 0.5, "pretrained_model": 0.2, "keyword_analysis": 0.3}
    return {"custom_model": 0.4, "pretrained_model": 0.4, "keyword_analysis": 0.2}


def _weighted_vote(votes: List[Dict], has_patterns: bool) -> tuple:
    """Combine weighted votes into ``(final_category, final_confidence)``.

    Each vote is a dict with ``"category"``, ``"confidence"`` and ``"weight"``.
    A category's score is the sum of ``confidence * weight`` over its votes.
    The highest score wins; on equal scores the category seen first wins,
    because ``sorted`` is stable. The returned confidence is the winner's
    score divided by the total weight of all votes. With no votes the result
    is ``("neutral", 0.5)``.
    """
    if not votes:
        return "neutral", 0.5

    category_scores: Dict[str, float] = {}
    for vote in votes:
        category = vote["category"]
        category_scores[category] = category_scores.get(category, 0) + vote["confidence"] * vote["weight"]

    total_weight = sum(vote["weight"] for vote in votes)
    ranked = sorted(category_scores.items(), key=lambda item: item[1], reverse=True)
    final_category, final_score = ranked[0]

    if has_patterns and len(ranked) > 1:
        second_category, second_score = ranked[1]
        # Near-tie override: when explicit hate patterns were matched, let
        # "hate_speech" win if it trails the leader by less than 0.15 in raw
        # (un-normalized) weighted score.
        if second_category == "hate_speech" and (final_score - second_score) < 0.15:
            final_category, final_score = second_category, second_score

    return final_category, final_score / total_weight


def _build_reasons(custom_result, pretrained_result, keyword_result: Dict, has_patterns: bool) -> List[str]:
    """Build human-readable explanations for the ensemble decision.

    ``custom_result`` and ``pretrained_result`` may be falsy (model unavailable).
    Always returns at least one reason.
    """
    reasons: List[str] = []
    if has_patterns:
        reasons.append("Detected hate speech patterns in text structure")
    if custom_result and custom_result["category"] == "hate_speech":
        reasons.append(f"Custom model detected {custom_result['category']} with {custom_result['confidence']:.2%} confidence")
    if pretrained_result:
        if pretrained_result.get("translated"):
            reasons.append(f"Pretrained model analyzed translated text and identified {pretrained_result['category']}")
        elif pretrained_result["category"] != "neutral":
            reasons.append(f"Pretrained model identified {pretrained_result['category']} patterns")
    detected_keywords = keyword_result.get("detected_keywords")
    if detected_keywords:
        reasons.append(f"Found {len(detected_keywords)} hate/offensive keywords")
    return reasons or ["Classification based on content analysis"]


async def analyze_content(text: str) -> Dict:
    """Classify *text* with all three methods and combine them into one verdict.

    The language is detected first. The text is then scored by the custom
    model, the pretrained model (which receives the language for translation
    support) and the keyword analyser. Their ``category``/``confidence`` votes
    are merged by adaptive weighted voting.

    A model whose result is falsy (presumably ``None`` when the model is
    unavailable) does not vote and is reported with ``"available": False``.

    Args:
        text: The content to analyse.

    Returns:
        A dict with the keys ``"ensemble"`` (final category, confidence,
        reasons, weights used), ``"custom_model"``, ``"pretrained_model"``,
        ``"keyword_analysis"`` (per-method details), ``"highlighted_phrases"``,
        ``"detected_language"`` and ``"original_text"`` (truncated to 200
        characters plus ``"..."``).
    """
    language = detect_language(text)

    custom_result = await classifier.classify_with_custom_model(text, language)
    # Language is passed so the pretrained model can use translation support.
    pretrained_result = await classifier.classify_with_pretrained_model(text, language)
    # Normalize a missing keyword result to {} so every lookup below is safe;
    # previously a falsy result crashed on .get()/[] despite the later guard.
    keyword_result = classifier.classify_with_keywords(text, language) or {}

    has_patterns = keyword_result.get("pattern_matches", 0) > 0
    has_hate_keywords = keyword_result.get("hate_count", 0) > 0
    weights = _ensemble_weights(has_patterns, has_hate_keywords)

    sources = (
        ("custom_model", custom_result),
        ("pretrained_model", pretrained_result),
        ("keyword_analysis", keyword_result),
    )
    votes = [
        {"category": result["category"], "confidence": result["confidence"], "weight": weights[name]}
        for name, result in sources
        if result
    ]
    final_category, final_confidence = _weighted_vote(votes, has_patterns)
    reasons = _build_reasons(custom_result, pretrained_result, keyword_result, has_patterns)

    all_keywords = keyword_result.get("detected_keywords", [])
    highlighted_phrases = highlight_keywords(text, all_keywords) if all_keywords else []

    # Falsy results are reported as unavailable, matching whether they voted.
    custom = custom_result or {}
    pretrained = pretrained_result or {}

    return {
        "ensemble": {
            "category": final_category,
            "confidence": float(final_confidence),
            "reasons": reasons,
            "weights_used": weights,
        },
        "custom_model": {
            "available": bool(custom_result),
            "category": custom.get("category"),
            "confidence": custom.get("confidence"),
            "method": custom.get("method"),
            "raw_prediction": custom.get("raw_prediction"),
        },
        "pretrained_model": {
            "available": bool(pretrained_result),
            "category": pretrained.get("category"),
            "confidence": pretrained.get("confidence"),
            "method": pretrained.get("method"),
            "raw_labels": pretrained.get("raw_labels"),
            "translated": pretrained.get("translated", False),
            "translated_text": pretrained.get("translated_text"),
        },
        "keyword_analysis": {
            "available": bool(keyword_result),
            "category": keyword_result.get("category"),
            "confidence": keyword_result.get("confidence"),
            "method": keyword_result.get("method"),
            "detected_keywords": keyword_result.get("detected_keywords", []),
            "hate_count": keyword_result.get("hate_count", 0),
            "offensive_count": keyword_result.get("offensive_count", 0),
            "pattern_matches": keyword_result.get("pattern_matches", 0),
        },
        "highlighted_phrases": highlighted_phrases,
        "detected_language": language,
        "original_text": (text[:200] + "...") if len(text) > 200 else text,
    }