Spaces:
Sleeping
Sleeping
# core/classification.py (NEW FILE)
"""Improved hierarchical classification with LLM fallback."""
import json
import logging
import math
import os
import re
from typing import Any, Dict, List, Optional

from openai import OpenAI
# Logger named after this module so its records can be filtered per module.
logger: logging.Logger = logging.getLogger(name=__name__)
class ImprovedHierarchicalClassifier:
    """Three-level hierarchical text classifier with an optional LLM first pass.

    Level 1 (domain) is chosen by an OpenAI chat model when an API key is
    available, otherwise by keyword matching. Levels 2 (section) and 3 (topic)
    are always chosen by keyword matching, restricted to the children of the
    level above.

    The hierarchy is whatever ``core.utils.load_hierarchy`` returns; this class
    reads ``hierarchy['levels'][0]['values']`` (domain names),
    ``hierarchy['levels'][1]['mapping']`` / ``hierarchy['levels'][2]['mapping']``
    (parent name -> list of child names) and the optional
    ``hierarchy['doc_types']`` list. Missing levels or mappings yield
    "Unknown" for the affected level instead of raising.
    """

    # Chat model used when none is passed to __init__.
    DEFAULT_MODEL = "gpt-3.5-turbo"

    # Connective words produced by splitting multi-word category names
    # ("Research and Development", "Health & Safety"). They occur in almost
    # any text, so counting them as keywords would inflate the score of every
    # category whose name happens to contain them.
    _STOP_WORDS = frozenset({
        "a", "an", "and", "&", "for", "in", "of", "on", "or", "the", "to", "with",
    })

    # Cue words per document type, used by _infer_doc_type. Types not listed
    # here are matched by their own (lower-cased) name.
    _TYPE_KEYWORDS = {
        'policy': ['policy', 'regulation', 'rule', 'requirement', 'must', 'shall'],
        'manual': ['manual', 'guide', 'instruction', 'procedure', 'how to', 'step'],
        'report': ['report', 'analysis', 'findings', 'results', 'summary', 'conclusion'],
        'protocol': ['protocol', 'standard', 'specification', 'guideline'],
        'faq': ['faq', 'question', 'answer', 'q&a', 'frequently asked'],
        'agreement': ['agreement', 'contract', 'terms', 'conditions'],
        'guideline': ['guideline', 'recommendation', 'best practice', 'should'],
        'paper': ['abstract', 'introduction', 'methodology', 'conclusion', 'references'],
        'tutorial': ['tutorial', 'example', 'walkthrough', 'demo', 'lesson'],
        'specification': ['specification', 'requirement', 'definition', 'spec'],
        'record': ['record', 'log', 'entry', 'note', 'documentation'],
    }

    def __init__(self, hierarchy_name: str, use_llm: bool = True,
                 model: str = DEFAULT_MODEL):
        """
        Initialize the classifier.

        Args:
            hierarchy_name: Name of the hierarchy passed to
                ``core.utils.load_hierarchy``.
            use_llm: Whether to try LLM classification first. Forced to False
                when the ``OPENAI_API_KEY`` environment variable is not set.
            model: OpenAI chat model used for level-1 classification.
        """
        from core.utils import load_hierarchy

        self.hierarchy = load_hierarchy(hierarchy_name)
        self.hierarchy_name = hierarchy_name
        self.use_llm = use_llm
        self.model = model
        self.client = None  # set below only when an API key is available
        self._build_keyword_maps()

        if self.use_llm:
            api_key = os.getenv("OPENAI_API_KEY")
            if api_key:
                self.client = OpenAI(api_key=api_key)
            else:
                logger.warning("No OpenAI API key found, falling back to keyword matching")
                self.use_llm = False

    # ------------------------------------------------------------------
    # Hierarchy access
    # ------------------------------------------------------------------
    def _level(self, index: int) -> Dict[str, Any]:
        """Return level *index* of the hierarchy, or ``{}`` if it does not exist."""
        levels = self.hierarchy.get('levels') or []
        return levels[index] if 0 <= index < len(levels) else {}

    def _level_mapping(self, index: int) -> Dict[str, List[str]]:
        """Return the parent -> children mapping of level *index* (``{}`` if absent)."""
        return self._level(index).get('mapping') or {}

    def _domains(self) -> List[str]:
        """Return the level-1 domain names (empty list if none are defined)."""
        return list(self._level(0).get('values') or [])

    # ------------------------------------------------------------------
    # Keyword matching primitives
    # ------------------------------------------------------------------
    @classmethod
    def _name_keywords(cls, name: str) -> List[str]:
        """Split a category name into lower-case keywords.

        Connective words (``_STOP_WORDS``) and punctuation-only tokens are
        dropped because they would match nearly any text.
        """
        return [
            word for word in name.lower().split()
            if word not in cls._STOP_WORDS and any(ch.isalnum() for ch in word)
        ]

    @staticmethod
    def _term_pattern(term: str) -> str:
        """Regex matching *term* as a whole word or phrase.

        An optional plural ``s``/``es`` suffix is allowed, so 'note' matches
        'notes' but not 'denote', and 'log' does not match 'technology'.
        """
        return r"(?<!\w)" + re.escape(term) + r"(?:e?s)?(?!\w)"

    @classmethod
    def _contains_term(cls, term: str, text: str) -> bool:
        """Return True if *term* occurs in *text* as a whole word or phrase."""
        return bool(term) and re.search(cls._term_pattern(term), text) is not None

    @classmethod
    def _count_term(cls, term: str, text: str) -> int:
        """Count non-overlapping whole-word occurrences of *term* in *text*."""
        if not term:
            return 0
        return len(re.findall(cls._term_pattern(term), text))

    def _best_match(self, candidates: List[str], keyword_map: Dict[str, List[str]],
                    text: str, name_bonus: int, default: str) -> str:
        """Return the highest-scoring candidate for *text*, or *default*.

        A candidate scores 1 per keyword of ``keyword_map[candidate]`` found in
        *text*, plus *name_bonus* when its full name appears. Ties go to the
        earliest candidate; when no candidate scores above zero, *default* is
        returned.
        """
        scores: Dict[str, int] = {}
        for candidate in candidates:
            score = sum(
                1 for kw in keyword_map.get(candidate, [])
                if self._contains_term(kw, text)
            )
            if self._contains_term(candidate.lower(), text):
                score += name_bonus
            scores[candidate] = score
        if max(scores.values(), default=0) > 0:
            return max(scores, key=scores.get)
        return default

    def _build_keyword_maps(self) -> None:
        """Build keyword lists for every category at each level.

        Each category's keywords are the words of its own name (see
        ``_name_keywords``). Levels without a ``mapping`` get no entries.
        """
        self.level1_keywords: Dict[str, List[str]] = {
            domain: self._name_keywords(domain) for domain in self._domains()
        }
        self.level2_keywords: Dict[str, List[str]] = {
            section: self._name_keywords(section)
            for sections in self._level_mapping(1).values()
            for section in sections
        }
        self.level3_keywords: Dict[str, List[str]] = {
            topic: self._name_keywords(topic)
            for topics in self._level_mapping(2).values()
            for topic in topics
        }

    # ------------------------------------------------------------------
    # LLM reply handling
    # ------------------------------------------------------------------
    @staticmethod
    def _build_prompt(text: str, domains: List[str], doc_types: List[str]) -> str:
        """Build the user prompt listing the allowed labels and a text excerpt."""
        return (
            "You are a document classification expert. Classify the following "
            "text into the appropriate categories.\n"
            "\n"
            "**Available Domains:**\n"
            f"{', '.join(domains)}\n"
            "\n"
            "**Available Document Types:**\n"
            f"{', '.join(doc_types)}\n"
            "\n"
            "**Text to classify (first 800 characters):**\n"
            f"{text[:800]}\n"
            "\n"
            "Return a JSON object with:\n"
            '- "level1": the most appropriate domain\n'
            '- "confidence": confidence score (0.0-1.0)\n'
            '- "doc_type": the document type\n'
            '- "reasoning": brief explanation\n'
            "\n"
            "Example response:\n"
            "{\n"
            '  "level1": "Clinical Care",\n'
            '  "confidence": 0.85,\n'
            '  "doc_type": "protocol",\n'
            '  "reasoning": "Text discusses patient procedures"\n'
            "}"
        )

    @staticmethod
    def _parse_llm_json(raw: Optional[str]) -> Dict[str, Any]:
        """Extract the JSON object from a chat-completion reply.

        Handles bare JSON, ```json / ``` fenced blocks, and an object embedded
        in surrounding prose.

        Raises:
            ValueError: if no JSON object can be decoded (json.JSONDecodeError
                is a ValueError subclass).
        """
        text = (raw or "").strip()
        if "```json" in text:
            text = text.split("```json", 1)[1].split("```", 1)[0].strip()
        elif "```" in text:
            text = text.split("```")[1].strip()
        # Trim any prose around the outermost braces.
        start, end = text.find("{"), text.rfind("}")
        if start != -1 and end > start:
            text = text[start:end + 1]
        result = json.loads(text)
        if not isinstance(result, dict):
            raise ValueError(f"expected a JSON object, got {type(result).__name__}")
        return result

    @staticmethod
    def _coerce_confidence(value: Any, default: float = 0.5) -> float:
        """Return *value* as a float clamped to [0.0, 1.0].

        The value comes from untrusted model output and may be a string,
        out of range, or missing; non-numeric values, booleans and NaN yield
        *default*. Downstream formatting (``%.2f``) relies on a real float.
        """
        if isinstance(value, bool):
            return default
        try:
            confidence = float(value)
        except (TypeError, ValueError):
            return default
        if math.isnan(confidence):
            return default
        return min(max(confidence, 0.0), 1.0)

    @staticmethod
    def _match_choice(value: Any, choices: List[str]) -> Optional[str]:
        """Return the entry of *choices* equal to *value*, ignoring case and
        surrounding whitespace, or None if there is no such entry."""
        if not isinstance(value, str):
            return None
        wanted = value.strip().lower()
        return next((choice for choice in choices if choice.lower() == wanted), None)

    def _llm_doc_type(self, raw: Any, doc_types: List[str], text_lower: str) -> str:
        """Validate the LLM's doc type, inferring one from keywords if unusable.

        When the hierarchy defines doc types, the reply must name one of them;
        otherwise any non-blank string is accepted.
        """
        if doc_types:
            doc_type = self._match_choice(raw, doc_types)
        elif isinstance(raw, str) and raw.strip():
            doc_type = raw.strip()
        else:
            doc_type = None
        return doc_type if doc_type is not None else self._infer_doc_type(text_lower)

    # ------------------------------------------------------------------
    # Classification
    # ------------------------------------------------------------------
    def classify_with_llm(self, text: str) -> Dict[str, Any]:
        """
        Classify *text*, asking the LLM for the domain (level 1).

        Levels 2 and 3 are then chosen by keyword matching under that domain.
        Falls back to keyword-only classification when the LLM is disabled or
        has no client, the API call fails, the reply is not a JSON object, or
        the reply names a domain not in the hierarchy.

        Args:
            text: Text to classify (only the first 800 characters are sent).

        Returns:
            Dict with level1, level2, level3, doc_type, confidence (float in
            [0, 1]) and method ("llm", or "keyword" on fallback).
        """
        if not self.use_llm or self.client is None:
            return self._fallback_classification(text)

        domains = self._domains()
        doc_types = self.hierarchy.get('doc_types') or []
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "You are a precise document classifier. Always respond with valid JSON."},
                    {"role": "user", "content": self._build_prompt(text, domains, doc_types)},
                ],
                temperature=0.1,
                max_tokens=200,
            )
            result = self._parse_llm_json(response.choices[0].message.content)
        except Exception as exc:  # API, network and parse errors all degrade to keywords
            logger.error("LLM classification failed: %s", exc)
            return self._fallback_classification(text)

        level1 = self._match_choice(result.get("level1"), domains)
        if level1 is None:
            logger.warning("LLM returned invalid domain: %s", result.get("level1"))
            return self._fallback_classification(text)

        text_lower = text.lower()
        level2 = self._classify_level2(text_lower, level1)
        return {
            "level1": level1,
            "level2": level2,
            "level3": self._classify_level3(text_lower, level2),
            "doc_type": self._llm_doc_type(result.get("doc_type"), doc_types, text_lower),
            "confidence": self._coerce_confidence(result.get("confidence")),
            "method": "llm",
        }

    def _fallback_classification(self, text: str) -> Dict[str, Any]:
        """Classify *text* with keyword matching only (confidence fixed at 0.3)."""
        text_lower = text.lower()
        level1 = self._classify_level1(text_lower)
        level2 = self._classify_level2(text_lower, level1)
        return {
            "level1": level1,
            "level2": level2,
            "level3": self._classify_level3(text_lower, level2),
            "doc_type": self._infer_doc_type(text_lower),
            "confidence": 0.3,  # Low confidence for keyword matching
            "method": "keyword",
        }

    def _classify_level1(self, text: str) -> str:
        """Pick the domain for lower-cased *text* (name match bonus +5).

        Returns the first domain when nothing matches, or "Unknown" when the
        hierarchy defines no domains.
        """
        domains = self._domains()
        default = domains[0] if domains else "Unknown"
        return self._best_match(domains, self.level1_keywords, text, 5, default)

    def _classify_level2(self, text: str, level1: str) -> str:
        """Pick the section under *level1* for lower-cased *text* (bonus +3).

        Returns "Unknown" if *level1* has no sections, else the first section
        when nothing matches.
        """
        sections = self._level_mapping(1).get(level1, [])
        if not sections:
            return "Unknown"
        return self._best_match(sections, self.level2_keywords, text, 3, sections[0])

    def _classify_level3(self, text: str, level2: str) -> str:
        """Pick the topic under *level2* for lower-cased *text* (bonus +3).

        Returns "Unknown" if *level2* has no topics, else the first topic when
        nothing matches.
        """
        topics = self._level_mapping(2).get(level2, [])
        if not topics:
            return "Unknown"
        return self._best_match(topics, self.level3_keywords, text, 3, topics[0])

    def _infer_doc_type(self, text: str) -> str:
        """Guess the document type of lower-cased *text*.

        Each of the hierarchy's ``doc_types`` is scored by whole-word hits of
        its cue words (``_TYPE_KEYWORDS``, or its own name if unlisted); the
        best one wins, the first type when nothing matches. Uses
        ``['unknown']`` when the hierarchy defines no doc types.
        """
        doc_types = self.hierarchy.get('doc_types') or ['unknown']
        scores = {
            doc_type: sum(
                self._count_term(kw, text)
                for kw in self._TYPE_KEYWORDS.get(doc_type.lower(), [doc_type.lower()])
            )
            for doc_type in doc_types
        }
        if max(scores.values()) > 0:
            return max(scores, key=scores.get)
        return doc_types[0]

    def classify_text(self, text: str, doc_type: Optional[str] = None) -> Dict[str, Any]:
        """
        Classify text into hierarchical categories.

        Args:
            text: Text to classify.
            doc_type: Optional document type that overrides the inferred one.

        Returns:
            Dict with level1, level2, level3, doc_type, confidence and method.
        """
        if self.use_llm:
            result = self.classify_with_llm(text)
        else:
            result = self._fallback_classification(text)

        if doc_type:
            result["doc_type"] = doc_type

        # Lazy %-formatting: a malformed field cannot raise from here.
        logger.info(
            "Classification: %s > %s > %s (%s) [method: %s, confidence: %.2f]",
            result['level1'], result['level2'], result['level3'], result['doc_type'],
            result.get('method', 'unknown'), result.get('confidence', 0),
        )
        return result