# hierarchical-rag-eval / core / classification.py
# hh786's picture
# Deployment of Hierarchical RAG system
# c54dcef
# core/classification.py (NEW FILE)
"""Improved hierarchical classification with LLM fallback."""
import json
import logging
import os
from typing import Any, Dict, List, Optional

from openai import OpenAI
# Module-level logger named after this module (``__name__``).
logger = logging.getLogger(__name__)
class ImprovedHierarchicalClassifier:
    """Classify text into a three-level hierarchy, optionally using an LLM.

    The hierarchy is a dict with a ``'levels'`` list (level 1 carries
    ``'values'``; levels 2 and 3 carry a parent -> children ``'mapping'``)
    and an optional ``'doc_types'`` list. Level 1 is chosen by the LLM when
    available, otherwise by keyword matching; levels 2 and 3 and the
    fallback document type are always keyword-based.
    """

    # Defaults that used to be hard-coded inside the methods.
    DEFAULT_LLM_MODEL = "gpt-3.5-turbo"
    MAX_PROMPT_CHARS = 800
    KEYWORD_CONFIDENCE = 0.3  # Low confidence for keyword matching
    DEFAULT_LLM_CONFIDENCE = 0.5
    UNKNOWN = "Unknown"

    # Indicator words per document type; a type that is not listed here is
    # matched on its own name.
    _DOC_TYPE_KEYWORDS = {
        'policy': ['policy', 'regulation', 'rule', 'requirement', 'must', 'shall'],
        'manual': ['manual', 'guide', 'instruction', 'procedure', 'how to', 'step'],
        'report': ['report', 'analysis', 'findings', 'results', 'summary', 'conclusion'],
        'protocol': ['protocol', 'standard', 'specification', 'guideline'],
        'faq': ['faq', 'question', 'answer', 'q&a', 'frequently asked'],
        'agreement': ['agreement', 'contract', 'terms', 'conditions'],
        'guideline': ['guideline', 'recommendation', 'best practice', 'should'],
        'paper': ['abstract', 'introduction', 'methodology', 'conclusion', 'references'],
        'tutorial': ['tutorial', 'example', 'walkthrough', 'demo', 'lesson'],
        'specification': ['specification', 'requirement', 'definition', 'spec'],
        'record': ['record', 'log', 'entry', 'note', 'documentation'],
    }

    def __init__(self, hierarchy_name: str, use_llm: bool = True,
                 llm_model: str = DEFAULT_LLM_MODEL):
        """
        Initialize improved classifier.

        Args:
            hierarchy_name: Name of hierarchy to use
            use_llm: Whether to use LLM for classification
            llm_model: Chat model name sent to the OpenAI API
        """
        # Local import kept from the original implementation.
        from core.utils import load_hierarchy

        self.hierarchy = load_hierarchy(hierarchy_name)
        self.hierarchy_name = hierarchy_name
        self.use_llm = use_llm
        self.llm_model = llm_model
        # Always defined so later code can test it instead of hitting
        # AttributeError when no client was created.
        self.client = None
        self._build_keyword_maps()

        # Initialize OpenAI client if using LLM
        if self.use_llm:
            api_key = os.getenv("OPENAI_API_KEY")
            if api_key:
                self.client = OpenAI(api_key=api_key)
            else:
                logger.warning("No OpenAI API key found, falling back to keyword matching")
                self.use_llm = False

    # ------------------------------------------------------------------
    # Hierarchy access helpers
    # ------------------------------------------------------------------
    def _domains(self) -> List[str]:
        """Return the level-1 values, or an empty list if none are defined."""
        levels = self.hierarchy.get('levels') or []
        if not levels:
            return []
        return list(levels[0].get('values') or [])

    def _level_mapping(self, index: int) -> Dict[str, List[str]]:
        """Return the parent -> children mapping of level *index*, or ``{}``."""
        levels = self.hierarchy.get('levels') or []
        if index >= len(levels):
            return {}
        return levels[index].get('mapping') or {}

    def _keywords_for_level(self, index: int) -> Dict[str, List[str]]:
        """Map every child name of level *index* to its lower-cased words."""
        keywords: Dict[str, List[str]] = {}
        for names in self._level_mapping(index).values():
            for name in names:
                keywords[name] = name.lower().split()
        return keywords

    def _build_keyword_maps(self) -> None:
        """Build keyword mappings for classification."""
        # Level 1: domain keywords
        self.level1_keywords = {
            domain: domain.lower().split() for domain in self._domains()
        }
        # Level 2: section keywords / Level 3: topic keywords
        self.level2_keywords = self._keywords_for_level(1)
        self.level3_keywords = self._keywords_for_level(2)

    # ------------------------------------------------------------------
    # LLM response helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _extract_json(raw: str) -> str:
        """Strip an optional markdown code fence from an LLM reply."""
        cleaned = raw.strip()
        if "```json" in cleaned:
            return cleaned.split("```json")[1].split("```")[0].strip()
        if "```" in cleaned:
            return cleaned.split("```")[1].split("```")[0].strip()
        return cleaned

    @staticmethod
    def _coerce_confidence(value: Any, default: float = 0.5) -> float:
        """Convert *value* to a float in [0.0, 1.0], or return *default*."""
        try:
            score = float(value)
        except (TypeError, ValueError):
            return default
        if score != score:  # NaN compares unequal to itself
            return default
        return min(1.0, max(0.0, score))

    def _build_prompt(self, text: str, domains: List[str], doc_types: List[str]) -> str:
        """Assemble the user prompt sent to the chat model."""
        lines = [
            "You are a document classification expert. Classify the following text into the appropriate categories.",
            "**Available Domains:**",
            ", ".join(domains),
            "**Available Document Types:**",
            ", ".join(doc_types),
            f"**Text to classify (first {self.MAX_PROMPT_CHARS} characters):**",
            text[:self.MAX_PROMPT_CHARS],
            "Return a JSON object with:",
            '- "level1": the most appropriate domain',
            '- "confidence": confidence score (0.0-1.0)',
            '- "doc_type": the document type',
            '- "reasoning": brief explanation',
            "Example response:",
            "{",
            '"level1": "Clinical Care",',
            '"confidence": 0.85,',
            '"doc_type": "protocol",',
            '"reasoning": "Text discusses patient procedures"',
            "}",
        ]
        return "\n".join(lines)

    def classify_with_llm(self, text: str) -> Dict[str, Any]:
        """
        Classify using LLM with structured output.

        Falls back to keyword classification when the LLM is disabled, no
        client or domains are available, the request or JSON parsing fails,
        or the returned domain is not one of the hierarchy's domains.

        Args:
            text: Text to classify

        Returns:
            Classification with confidence scores
        """
        text = text or ""
        domains = self._domains()
        if not self.use_llm or getattr(self, "client", None) is None or not domains:
            return self._fallback_classification(text)

        doc_types = list(self.hierarchy.get('doc_types') or [])
        try:
            response = self.client.chat.completions.create(
                model=self.llm_model,
                messages=[
                    {"role": "system", "content": "You are a precise document classifier. Always respond with valid JSON."},
                    {"role": "user", "content": self._build_prompt(text, domains, doc_types)},
                ],
                temperature=0.1,
                max_tokens=200,
            )
            # Parse LLM response (handle markdown code blocks)
            result = json.loads(self._extract_json(response.choices[0].message.content))
            if not isinstance(result, dict):
                raise ValueError("LLM response is not a JSON object")

            # Validate level1 is in available domains
            level1 = result.get("level1")
            if level1 not in domains:
                logger.warning("LLM returned invalid domain: %s", level1)
                return self._fallback_classification(text)

            # Add level2 and level3 based on level1
            text_lower = text.lower()
            level2 = self._classify_level2(text_lower, level1)
            level3 = self._classify_level3(text_lower, level2)

            # Only trust a doc_type that is a string and, when the hierarchy
            # lists document types, one of them; otherwise infer it.
            doc_type = result.get("doc_type")
            if not isinstance(doc_type, str) or (doc_types and doc_type not in doc_types):
                doc_type = self._infer_doc_type(text_lower)

            return {
                "level1": level1,
                "level2": level2,
                "level3": level3,
                "doc_type": doc_type,
                # Coerced so callers can always format it as a number.
                "confidence": self._coerce_confidence(
                    result.get("confidence"), self.DEFAULT_LLM_CONFIDENCE
                ),
                "method": "llm",
            }
        except Exception as e:
            # Deliberately broad: any API or parsing failure degrades to
            # keyword classification instead of failing the caller.
            logger.error("LLM classification failed: %s", e)
            return self._fallback_classification(text)

    # ------------------------------------------------------------------
    # Keyword-based classification
    # ------------------------------------------------------------------
    def _fallback_classification(self, text: str) -> Dict[str, Any]:
        """Fallback to keyword-based classification."""
        text_lower = (text or "").lower()
        level1 = self._classify_level1(text_lower)
        level2 = self._classify_level2(text_lower, level1)
        level3 = self._classify_level3(text_lower, level2)
        return {
            "level1": level1,
            "level2": level2,
            "level3": level3,
            "doc_type": self._infer_doc_type(text_lower),
            "confidence": self.KEYWORD_CONFIDENCE,
            "method": "keyword",
        }

    @staticmethod
    def _best_match(text: str, candidates: List[str],
                    keyword_map: Dict[str, List[str]],
                    exact_bonus: int) -> Optional[str]:
        """Return the best-scoring candidate, or ``None`` if nothing matches.

        A candidate scores one point per keyword found in *text* (substring
        match) plus *exact_bonus* when its full lower-cased name occurs.
        Ties go to the earliest candidate.
        """
        best, best_score = None, 0
        for candidate in candidates:
            score = sum(1 for kw in keyword_map.get(candidate, []) if kw in text)
            if candidate.lower() in text:
                score += exact_bonus
            if score > best_score:
                best, best_score = candidate, score
        return best

    def _classify_level1(self, text: str) -> str:
        """Classify domain (level 1) using keywords.

        Returns the first domain when nothing matches, or ``"Unknown"`` when
        the hierarchy defines no domains.
        """
        domains = self._domains()
        if not domains:
            return self.UNKNOWN
        # Boost of 5 if the exact domain name appears.
        return self._best_match(text, domains, self.level1_keywords, 5) or domains[0]

    def _classify_level2(self, text: str, level1: str) -> str:
        """Classify section (level 2) based on level 1."""
        sections = self._level_mapping(1).get(level1, [])
        if not sections:
            return self.UNKNOWN
        return self._best_match(text, sections, self.level2_keywords, 3) or sections[0]

    def _classify_level3(self, text: str, level2: str) -> str:
        """Classify topic (level 3) based on level 2."""
        topics = self._level_mapping(2).get(level2, [])
        if not topics:
            return self.UNKNOWN
        return self._best_match(text, topics, self.level3_keywords, 3) or topics[0]

    def _infer_doc_type(self, text: str) -> str:
        """Infer document type from content.

        Scores each of the hierarchy's document types by the total number of
        occurrences of its indicator words in *text*. Returns the first type
        when nothing matches, and ``'unknown'`` when the hierarchy lists no
        document types.
        """
        doc_types = self.hierarchy.get('doc_types') or ['unknown']
        best, best_score = doc_types[0], 0
        for doc_type in doc_types:
            keywords = self._DOC_TYPE_KEYWORDS.get(doc_type, [doc_type])
            score = sum(text.count(kw) for kw in keywords)
            if score > best_score:
                best, best_score = doc_type, score
        return best

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------
    def classify_text(self, text: str, doc_type: Optional[str] = None) -> Dict[str, Any]:
        """
        Classify text into hierarchical categories.

        Args:
            text: Text to classify
            doc_type: Optional document type override

        Returns:
            Dictionary with level1, level2, level3, doc_type, confidence
            and method
        """
        # Try LLM classification first
        if self.use_llm:
            result = self.classify_with_llm(text)
        else:
            result = self._fallback_classification(text)

        # Override doc_type if provided
        if doc_type:
            result["doc_type"] = doc_type

        logger.info(
            "Classification: %s > %s > %s (%s) [method: %s, confidence: %.2f]",
            result['level1'], result['level2'], result['level3'],
            result['doc_type'], result.get('method', 'unknown'),
            result.get('confidence', 0),
        )
        return result