# core/classification.py (NEW FILE)
"""Improved hierarchical classification with LLM fallback."""

import json
import logging
import math
import os
from typing import Any, Callable, Dict, List, Optional, Sequence

try:
    from openai import OpenAI
except ImportError:  # optional dependency: keyword matching still works without it
    OpenAI = None

logger = logging.getLogger(__name__)


class ImprovedHierarchicalClassifier:
    """Enhanced classifier with LLM-based fallback.

    Level 1 (domain) is chosen by the LLM when one is available, otherwise by
    keyword matching.  Levels 2 and 3 and, when needed, the document type are
    always derived by keyword matching against the loaded hierarchy.
    """

    # Model used for LLM classification unless overridden in ``__init__``.
    DEFAULT_MODEL = "gpt-3.5-turbo"
    model = DEFAULT_MODEL

    # Number of leading characters of the text that are sent to the LLM.
    PROMPT_CHAR_LIMIT = 800

    # Score added when the full category name (not just one of its words)
    # appears in the text.
    _LEVEL1_EXACT_BONUS = 5
    _SUBLEVEL_EXACT_BONUS = 3

    # Keywords counted when inferring a document type.  Types missing from
    # this table are matched on their own name.
    _DOC_TYPE_KEYWORDS: Dict[str, List[str]] = {
        'policy': ['policy', 'regulation', 'rule', 'requirement', 'must', 'shall'],
        'manual': ['manual', 'guide', 'instruction', 'procedure', 'how to', 'step'],
        'report': ['report', 'analysis', 'findings', 'results', 'summary', 'conclusion'],
        'protocol': ['protocol', 'standard', 'specification', 'guideline'],
        'faq': ['faq', 'question', 'answer', 'q&a', 'frequently asked'],
        'agreement': ['agreement', 'contract', 'terms', 'conditions'],
        'guideline': ['guideline', 'recommendation', 'best practice', 'should'],
        'paper': ['abstract', 'introduction', 'methodology', 'conclusion', 'references'],
        'tutorial': ['tutorial', 'example', 'walkthrough', 'demo', 'lesson'],
        'specification': ['specification', 'requirement', 'definition', 'spec'],
        'record': ['record', 'log', 'entry', 'note', 'documentation'],
    }

    def __init__(self, hierarchy_name: str, use_llm: bool = True,
                 model: str = DEFAULT_MODEL):
        """
        Initialize improved classifier.

        Args:
            hierarchy_name: Name of hierarchy to use
            use_llm: Whether to use LLM for classification
            model: Chat model name used for LLM classification
        """
        from core.utils import load_hierarchy

        self.hierarchy = load_hierarchy(hierarchy_name)
        self.hierarchy_name = hierarchy_name
        self.use_llm = use_llm
        self.model = model
        # Always defined, so later code can test it instead of risking an
        # AttributeError when the LLM is disabled.
        self.client: Optional[Any] = None
        self._build_keyword_maps()

        # Initialize OpenAI client if using LLM
        if self.use_llm:
            api_key = os.getenv("OPENAI_API_KEY")
            if OpenAI is None:
                logger.warning("openai package is not installed, falling back to keyword matching")
                self.use_llm = False
            elif api_key:
                self.client = OpenAI(api_key=api_key)
            else:
                logger.warning("No OpenAI API key found, falling back to keyword matching")
                self.use_llm = False

    # ------------------------------------------------------------------
    # Hierarchy access helpers
    # ------------------------------------------------------------------
    def _domains(self) -> List[str]:
        """Return the level-1 values, or an empty list if none are defined."""
        levels = self.hierarchy.get('levels') or []
        if not levels:
            return []
        return list(levels[0].get('values') or [])

    def _level_mapping(self, index: int) -> Dict[str, List[str]]:
        """Return the parent -> children mapping of a level, or ``{}`` if absent."""
        levels = self.hierarchy.get('levels') or []
        if index >= len(levels):
            return {}
        return levels[index].get('mapping') or {}

    def _build_keyword_maps(self) -> None:
        """Build keyword mappings for classification.

        Each category name is lower-cased and split on whitespace; the
        resulting words are the keywords searched for in the text.
        """
        self.level1_keywords = {}
        self.level2_keywords = {}
        self.level3_keywords = {}

        # Level 1: domain keywords
        for domain in self._domains():
            self.level1_keywords[domain] = domain.lower().split()

        # Level 2: section keywords
        for sections in self._level_mapping(1).values():
            for section in sections:
                self.level2_keywords[section] = section.lower().split()

        # Level 3: topic keywords
        for topics in self._level_mapping(2).values():
            for topic in topics:
                self.level3_keywords[topic] = topic.lower().split()

    # ------------------------------------------------------------------
    # Scoring helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _pick_best(candidates: Sequence[str], score_of: Callable[[str], int]) -> str:
        """Return the highest-scoring candidate; the first one wins ties.

        ``candidates`` must be non-empty.  When nothing scores above zero the
        first candidate is returned.
        """
        best, best_score = candidates[0], 0
        for candidate in candidates:
            score = score_of(candidate)
            if score > best_score:
                best, best_score = candidate, score
        return best

    @staticmethod
    def _keyword_score(text: str, name: str, keywords: Sequence[str],
                       exact_bonus: int) -> int:
        """Count keywords contained in ``text``, plus a bonus for the full name."""
        score = sum(1 for kw in keywords if kw in text)
        if name.lower() in text:
            score += exact_bonus
        return score

    @staticmethod
    def _extract_json_text(raw: str) -> str:
        """Strip an optional markdown code fence from an LLM reply."""
        cleaned = raw.strip()
        if "```json" in cleaned:
            return cleaned.split("```json")[1].split("```")[0].strip()
        if "```" in cleaned:
            return cleaned.split("```")[1].strip()
        return cleaned

    @staticmethod
    def _normalize_confidence(value: Any, default: float = 0.5) -> float:
        """Coerce an LLM-supplied confidence to a float within [0.0, 1.0].

        Missing, non-numeric, boolean or NaN values yield ``default``.
        """
        if isinstance(value, bool):
            return default
        try:
            score = float(value)
        except (TypeError, ValueError):
            return default
        if math.isnan(score):
            return default
        return min(1.0, max(0.0, score))

    def _build_prompt(self, text: str, domains: Sequence[str],
                      doc_types: Sequence[str]) -> str:
        """Build the user prompt listing the allowed domains and document types."""
        limit = self.PROMPT_CHAR_LIMIT
        return f"""You are a document classification expert. Classify the following text into the appropriate categories.

**Available Domains:**
{', '.join(domains)}

**Available Document Types:**
{', '.join(doc_types)}

**Text to classify (first {limit} characters):**
{text[:limit]}

Return a JSON object with:
- "level1": the most appropriate domain
- "confidence": confidence score (0.0-1.0)
- "doc_type": the document type
- "reasoning": brief explanation

Example response:
{{
    "level1": "Clinical Care",
    "confidence": 0.85,
    "doc_type": "protocol",
    "reasoning": "Text discusses patient procedures"
}}"""

    # ------------------------------------------------------------------
    # Classification
    # ------------------------------------------------------------------
    def classify_with_llm(self, text: str) -> Dict[str, Any]:
        """
        Classify using LLM with structured output.

        Falls back to keyword classification when the LLM is disabled, no
        client is available, the request or JSON parsing fails, or the LLM
        names a domain that is not in the hierarchy.

        Args:
            text: Text to classify

        Returns:
            Classification with confidence scores
        """
        # getattr: instances built without __init__ may lack the attribute.
        if not self.use_llm or getattr(self, "client", None) is None:
            return self._fallback_classification(text)

        domains = self._domains()
        doc_types = self.hierarchy.get('doc_types') or []

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "You are a precise document classifier. Always respond with valid JSON."},
                    {"role": "user", "content": self._build_prompt(text, domains, doc_types)},
                ],
                temperature=0.1,
                max_tokens=200,
            )
            # content may be None; an empty string then fails JSON parsing below.
            raw_reply = response.choices[0].message.content or ""
            result = json.loads(self._extract_json_text(raw_reply))
        except Exception as e:
            # Deliberate best-effort boundary: any API or parsing failure
            # degrades to keyword matching instead of propagating.
            logger.error("LLM classification failed: %s", e)
            return self._fallback_classification(text)

        if not isinstance(result, dict):
            logger.warning("LLM returned a non-object JSON payload")
            return self._fallback_classification(text)

        # Validate level1 is in available domains
        level1 = result.get("level1")
        if level1 not in domains:
            logger.warning("LLM returned invalid domain: %s", level1)
            return self._fallback_classification(text)

        # Add level2 and level3 based on level1
        text_lower = text.lower()
        level2 = self._classify_level2(text_lower, level1)
        level3 = self._classify_level3(text_lower, level2)

        # Accept the LLM's document type only if it is a string and, when the
        # hierarchy lists document types, one of them.
        doc_type = result.get("doc_type")
        if not isinstance(doc_type, str) or (doc_types and doc_type not in doc_types):
            doc_type = self._infer_doc_type(text_lower)

        return {
            "level1": level1,
            "level2": level2,
            "level3": level3,
            "doc_type": doc_type,
            "confidence": self._normalize_confidence(result.get("confidence")),
            "method": "llm",
        }

    def _fallback_classification(self, text: str) -> Dict[str, Any]:
        """Fallback to keyword-based classification."""
        text_lower = text.lower()

        level1 = self._classify_level1(text_lower)
        level2 = self._classify_level2(text_lower, level1)
        level3 = self._classify_level3(text_lower, level2)

        return {
            "level1": level1,
            "level2": level2,
            "level3": level3,
            "doc_type": self._infer_doc_type(text_lower),
            "confidence": 0.3,  # Low confidence for keyword matching
            "method": "keyword",
        }

    def _classify_level1(self, text: str) -> str:
        """Classify domain (level 1) using keywords.

        Args:
            text: Lower-cased text to search

        Returns:
            Best-scoring domain, the first domain when nothing matches, or
            "Unknown" when the hierarchy defines no domains.
        """
        domains = self._domains()
        if not domains:
            return "Unknown"
        return self._pick_best(
            domains,
            lambda domain: self._keyword_score(
                text, domain, self.level1_keywords.get(domain, []),
                self._LEVEL1_EXACT_BONUS),
        )

    def _classify_level2(self, text: str, level1: str) -> str:
        """Classify section (level 2) based on level 1.

        Args:
            text: Lower-cased text to search
            level1: Domain whose sections are considered

        Returns:
            Best-scoring section, the first section when nothing matches, or
            "Unknown" when the domain has no sections.
        """
        sections = self._level_mapping(1).get(level1, [])
        if not sections:
            return "Unknown"
        return self._pick_best(
            sections,
            lambda section: self._keyword_score(
                text, section, self.level2_keywords.get(section, []),
                self._SUBLEVEL_EXACT_BONUS),
        )

    def _classify_level3(self, text: str, level2: str) -> str:
        """Classify topic (level 3) based on level 2.

        Args:
            text: Lower-cased text to search
            level2: Section whose topics are considered

        Returns:
            Best-scoring topic, the first topic when nothing matches, or
            "Unknown" when the section has no topics.
        """
        topics = self._level_mapping(2).get(level2, [])
        if not topics:
            return "Unknown"
        return self._pick_best(
            topics,
            lambda topic: self._keyword_score(
                text, topic, self.level3_keywords.get(topic, []),
                self._SUBLEVEL_EXACT_BONUS),
        )

    def _infer_doc_type(self, text: str) -> str:
        """Infer document type from content.

        Each document type listed in the hierarchy is scored by the total
        number of occurrences of its keywords in ``text``.

        Args:
            text: Lower-cased text to search

        Returns:
            Best-scoring document type, the first listed type when nothing
            matches, or "unknown" when the hierarchy lists none.
        """
        # `or` also covers a present-but-empty list, which would otherwise
        # leave nothing to pick from.
        doc_types = self.hierarchy.get('doc_types') or ['unknown']
        return self._pick_best(
            doc_types,
            lambda doc_type: sum(
                text.count(kw)
                for kw in self._DOC_TYPE_KEYWORDS.get(doc_type, [doc_type])),
        )

    def classify_text(self, text: str, doc_type: Optional[str] = None) -> Dict[str, Any]:
        """
        Classify text into hierarchical categories.

        Args:
            text: Text to classify
            doc_type: Optional document type override

        Returns:
            Dictionary with level1, level2, level3, doc_type, confidence
            and method
        """
        # Try LLM classification first
        if self.use_llm:
            result = self.classify_with_llm(text)
        else:
            result = self._fallback_classification(text)

        # Override doc_type if provided
        if doc_type:
            result["doc_type"] = doc_type

        logger.info(
            "Classification: %s > %s > %s (%s) [method: %s, confidence: %.2f]",
            result['level1'], result['level2'], result['level3'],
            result['doc_type'], result.get('method', 'unknown'),
            result.get('confidence', 0),
        )

        return result