File size: 10,618 Bytes
c54dcef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
# core/classification.py (NEW FILE)
"""Improved hierarchical classification with LLM fallback."""

import json
import logging
import os
from typing import Any, Dict, List, Optional

from openai import OpenAI

logger = logging.getLogger(__name__)


class ImprovedHierarchicalClassifier:
    """Three-level hierarchical text classifier with an optional LLM first pass.

    The loaded hierarchy is read as a dict-like object with these entries:

    * ``hierarchy['levels'][0]['values']``: list of level-1 domain names.
    * ``hierarchy['levels'][1]['mapping']``: optional, domain -> sections.
    * ``hierarchy['levels'][2]['mapping']``: optional, section -> topics.
    * ``hierarchy['doc_types']``: optional list of document-type names.

    When the LLM is enabled it chooses level 1 (and, if valid, the document
    type); levels 2 and 3 are always chosen by keyword matching. Any LLM
    failure falls back to pure keyword matching.
    """

    # Label returned for a level that cannot be resolved from the hierarchy.
    _UNKNOWN = "Unknown"
    # Chat model used by classify_with_llm; override on a subclass or instance.
    LLM_MODEL = "gpt-3.5-turbo"
    # Confidence reported when the LLM reply carries no usable score.
    _DEFAULT_LLM_CONFIDENCE = 0.5
    # Fixed, deliberately low confidence for keyword-only results.
    _KEYWORD_CONFIDENCE = 0.3
    # Indicator phrases per document type, counted by _infer_doc_type.
    _DOC_TYPE_KEYWORDS = {
        'policy': ['policy', 'regulation', 'rule', 'requirement', 'must', 'shall'],
        'manual': ['manual', 'guide', 'instruction', 'procedure', 'how to', 'step'],
        'report': ['report', 'analysis', 'findings', 'results', 'summary', 'conclusion'],
        'protocol': ['protocol', 'standard', 'specification', 'guideline'],
        'faq': ['faq', 'question', 'answer', 'q&a', 'frequently asked'],
        'agreement': ['agreement', 'contract', 'terms', 'conditions'],
        'guideline': ['guideline', 'recommendation', 'best practice', 'should'],
        'paper': ['abstract', 'introduction', 'methodology', 'conclusion', 'references'],
        'tutorial': ['tutorial', 'example', 'walkthrough', 'demo', 'lesson'],
        'specification': ['specification', 'requirement', 'definition', 'spec'],
        'record': ['record', 'log', 'entry', 'note', 'documentation'],
    }

    def __init__(self, hierarchy_name: str, use_llm: bool = True):
        """
        Initialize improved classifier.

        Args:
            hierarchy_name: Name of hierarchy to use
            use_llm: Whether to use LLM for classification. Silently
                downgraded to False when OPENAI_API_KEY is not set.
        """
        # Imported lazily, as in the original code.
        from core.utils import load_hierarchy

        self.hierarchy = load_hierarchy(hierarchy_name)
        self.hierarchy_name = hierarchy_name
        self.use_llm = use_llm
        # Always define the attribute so it exists even when the LLM is off.
        self.client = None
        self._build_keyword_maps()

        # Initialize OpenAI client if using LLM
        if self.use_llm:
            api_key = os.getenv("OPENAI_API_KEY")
            if api_key:
                self.client = OpenAI(api_key=api_key)
            else:
                logger.warning("No OpenAI API key found, falling back to keyword matching")
                self.use_llm = False

    # ------------------------------------------------------------------
    # Hierarchy access helpers
    # ------------------------------------------------------------------

    def _domains(self) -> List[str]:
        """Return the level-1 domain names, or an empty list if none exist."""
        levels = self.hierarchy.get('levels') or []
        if not levels:
            return []
        return levels[0].get('values') or []

    def _level_mapping(self, index: int) -> Optional[Dict[str, List[str]]]:
        """Return the ``mapping`` of hierarchy level *index*, or None.

        None is returned both when the level does not exist and when it has
        no ``mapping`` key, so shallow hierarchies do not raise IndexError.
        """
        levels = self.hierarchy.get('levels') or []
        if index >= len(levels) or 'mapping' not in levels[index]:
            return None
        return levels[index]['mapping']

    def _keywords_for_level(self, index: int) -> Dict[str, List[str]]:
        """Map every child name at level *index* to its lowercase words."""
        keyword_map = {}
        mapping = self._level_mapping(index) or {}
        for children in mapping.values():
            for child in children:
                keyword_map[child] = child.lower().split()
        return keyword_map

    def _build_keyword_maps(self) -> None:
        """Build keyword mappings for classification.

        Each category name is split on whitespace into lowercase keywords;
        the results are stored in ``level1_keywords``, ``level2_keywords``
        and ``level3_keywords``.
        """
        self.level1_keywords = {
            domain: domain.lower().split() for domain in self._domains()
        }
        self.level2_keywords = self._keywords_for_level(1)
        self.level3_keywords = self._keywords_for_level(2)

    # ------------------------------------------------------------------
    # Keyword scoring
    # ------------------------------------------------------------------

    @staticmethod
    def _best_match(text, candidates, keyword_map, exact_bonus) -> Optional[str]:
        """Return the highest-scoring candidate, or None if nothing matched.

        A candidate scores one point per keyword found as a substring of
        *text*, plus *exact_bonus* when its full lowercase name appears.
        Ties go to the earliest candidate.
        """
        best, best_score = None, 0
        for candidate in candidates:
            keywords = keyword_map.get(candidate, [])
            score = sum(1 for kw in keywords if kw in text)
            if candidate.lower() in text:
                score += exact_bonus
            # Strict '>' keeps the first candidate on ties.
            if score > best_score:
                best, best_score = candidate, score
        return best

    def _classify_level1(self, text: str) -> str:
        """Classify domain (level 1) using keywords.

        Args:
            text: Lowercased text to classify.

        Returns:
            The best-matching domain, the first domain when nothing matches,
            or "Unknown" when the hierarchy defines no domains.
        """
        best = self._best_match(
            text, self.level1_keywords, self.level1_keywords, exact_bonus=5
        )
        if best is not None:
            return best
        domains = self._domains()
        return domains[0] if domains else self._UNKNOWN

    def _classify_level2(self, text: str, level1: str) -> str:
        """Classify section (level 2) based on level 1.

        Returns "Unknown" when level 2 has no mapping or *level1* has no
        sections; otherwise the best match, defaulting to the first section.
        """
        mapping = self._level_mapping(1)
        if mapping is None:
            return self._UNKNOWN

        sections = mapping.get(level1, [])
        if not sections:
            return self._UNKNOWN

        best = self._best_match(text, sections, self.level2_keywords, exact_bonus=3)
        return sections[0] if best is None else best

    def _classify_level3(self, text: str, level2: str) -> str:
        """Classify topic (level 3) based on level 2.

        Returns "Unknown" when level 3 has no mapping or *level2* has no
        topics; otherwise the best match, defaulting to the first topic.
        """
        mapping = self._level_mapping(2)
        if mapping is None:
            return self._UNKNOWN

        topics = mapping.get(level2, [])
        if not topics:
            return self._UNKNOWN

        best = self._best_match(text, topics, self.level3_keywords, exact_bonus=3)
        return topics[0] if best is None else best

    def _infer_doc_type(self, text: str) -> str:
        """Infer document type from content.

        Counts occurrences of each type's indicator phrases in *text*
        (expected lowercased). Types without a keyword list are matched on
        their own name.

        Returns:
            The highest-scoring type, or the first configured type when
            nothing matches. A missing or empty ``doc_types`` list yields
            'unknown'.
        """
        # 'or' also covers an explicitly empty list, which used to crash max().
        doc_types = self.hierarchy.get('doc_types') or ['unknown']

        best_type, best_score = doc_types[0], 0
        for doc_type in doc_types:
            keywords = self._DOC_TYPE_KEYWORDS.get(doc_type, [doc_type])
            score = sum(text.count(kw) for kw in keywords)
            if score > best_score:
                best_type, best_score = doc_type, score
        return best_type

    # ------------------------------------------------------------------
    # LLM helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _build_prompt(text: str, domains: List[str], doc_types: List[str]) -> str:
        """Render the classification prompt sent as the user message."""
        return f"""You are a document classification expert. Classify the following text into the appropriate categories.

**Available Domains:**
{', '.join(domains)}

**Available Document Types:**
{', '.join(doc_types)}

**Text to classify (first 800 characters):**
{text[:800]}

Return a JSON object with:
- "level1": the most appropriate domain
- "confidence": confidence score (0.0-1.0)
- "doc_type": the document type
- "reasoning": brief explanation

Example response:
{{
  "level1": "Clinical Care",
  "confidence": 0.85,
  "doc_type": "protocol",
  "reasoning": "Text discusses patient procedures"
}}"""

    @staticmethod
    def _extract_json_text(raw: str) -> str:
        """Strip an optional markdown code fence from around a JSON reply."""
        if "```json" in raw:
            body = raw.partition("```json")[2]
            return body.partition("```")[0].strip()
        if "```" in raw:
            return raw.split("```")[1].strip()
        return raw

    @classmethod
    def _coerce_confidence(cls, value: Any) -> float:
        """Convert an LLM-supplied confidence to a float within [0.0, 1.0].

        Non-numeric, missing or NaN values become the default LLM
        confidence, so callers can always format the result as a number.
        """
        try:
            score = float(value)
        except (TypeError, ValueError):
            return cls._DEFAULT_LLM_CONFIDENCE
        if score != score:  # NaN is the only value not equal to itself
            return cls._DEFAULT_LLM_CONFIDENCE
        return min(1.0, max(0.0, score))

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def classify_with_llm(self, text: str) -> Dict[str, Any]:
        """
        Classify using LLM with structured output.

        Args:
            text: Text to classify

        Returns:
            Dict with "level1", "level2", "level3", "doc_type",
            "confidence" (float in [0, 1]) and "method" ("llm" or
            "keyword"). Falls back to keyword classification when the LLM
            is disabled, no client or domains exist, the request or parsing
            fails, or the LLM names a domain outside the hierarchy.
        """
        domains = self._domains()
        # Without a client or any domain to pick from, a request is pointless.
        if not self.use_llm or self.client is None or not domains:
            return self._fallback_classification(text)

        # Broad catch is deliberate: any network, API or parsing problem
        # must degrade to keyword matching rather than abort classification.
        try:
            doc_types = self.hierarchy.get('doc_types') or []
            prompt = self._build_prompt(text, domains, doc_types)

            response = self.client.chat.completions.create(
                model=self.LLM_MODEL,
                messages=[
                    {"role": "system", "content": "You are a precise document classifier. Always respond with valid JSON."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.1,
                max_tokens=200
            )

            # content may be None; treat that as an empty (unparseable) reply.
            raw_reply = (response.choices[0].message.content or "").strip()
            result = json.loads(self._extract_json_text(raw_reply))
            if not isinstance(result, dict):
                raise ValueError("LLM response is not a JSON object")

            # Validate level1 is in available domains
            level1 = result.get("level1")
            if level1 not in domains:
                logger.warning("LLM returned invalid domain: %s", level1)
                return self._fallback_classification(text)

            # Levels 2 and 3 are always derived by keyword matching.
            text_lower = text.lower()
            level2 = self._classify_level2(text_lower, level1)
            level3 = self._classify_level3(text_lower, level2)

            # Accept the LLM's doc type only if it is a non-empty string
            # and, when the hierarchy lists doc types, one of them.
            llm_doc_type = result.get("doc_type")
            doc_type_ok = (
                isinstance(llm_doc_type, str)
                and bool(llm_doc_type)
                and (not doc_types or llm_doc_type in doc_types)
            )
            doc_type = llm_doc_type if doc_type_ok else self._infer_doc_type(text_lower)

            return {
                "level1": level1,
                "level2": level2,
                "level3": level3,
                "doc_type": doc_type,
                "confidence": self._coerce_confidence(result.get("confidence")),
                "method": "llm"
            }

        except Exception as e:
            logger.error("LLM classification failed: %s", e)
            return self._fallback_classification(text)

    def _fallback_classification(self, text: str) -> Dict[str, Any]:
        """Fallback to keyword-based classification.

        Returns the same keys as classify_with_llm, with a fixed confidence
        of 0.3 and method "keyword".
        """
        text_lower = text.lower()

        level1 = self._classify_level1(text_lower)
        level2 = self._classify_level2(text_lower, level1)
        level3 = self._classify_level3(text_lower, level2)

        return {
            "level1": level1,
            "level2": level2,
            "level3": level3,
            "doc_type": self._infer_doc_type(text_lower),
            "confidence": self._KEYWORD_CONFIDENCE,  # Low confidence for keyword matching
            "method": "keyword"
        }

    def classify_text(self, text: str, doc_type: Optional[str] = None) -> Dict[str, Any]:
        """
        Classify text into hierarchical categories.

        Args:
            text: Text to classify
            doc_type: Optional document type override

        Returns:
            Dictionary with level1, level2, level3, doc_type, confidence
            and method
        """
        # Try LLM classification first
        if self.use_llm:
            result = self.classify_with_llm(text)
        else:
            result = self._fallback_classification(text)

        # Override doc_type if provided
        if doc_type:
            result["doc_type"] = doc_type

        logger.info(
            "Classification: %s > %s > %s (%s) [method: %s, confidence: %.2f]",
            result['level1'], result['level2'], result['level3'],
            result['doc_type'], result.get('method', 'unknown'),
            result.get('confidence', 0),
        )

        return result