File size: 6,751 Bytes
b7f3196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
"""

CLI Interface for Healthcare Reason Classification



This provides a command-line interface for testing and using the 

healthcare reason classifier system with real healthcare data.

"""

import argparse
import json
import sys
from pathlib import Path

# Put the directory one level above this script's folder at the front of the
# import search path (only if it is not listed already).
REPO_ROOT = Path(__file__).resolve().parents[1]
if sys.path.count(str(REPO_ROOT)) == 0:
    sys.path.insert(0, str(REPO_ROOT))

def classify_single_query(query: str) -> bool:
    """Classify a single healthcare reason query and print the result.

    Prints the query, the predicted category with its confidence, and then
    every category probability in descending order.

    Args:
        query: Free-text healthcare reason to classify.

    Returns:
        True if a prediction was produced and printed; False if any
        exception occurred (the error is printed instead of being raised).
    """
    print(f"Query: {query}")
    print("-" * 50)

    try:
        # Keep the legacy relative 'classifier' search-path entry, but add it
        # only once: this function runs for every line typed in interactive
        # mode, and an unconditional append grows sys.path on each call.
        if 'classifier' not in sys.path:
            sys.path.append('classifier')
        # Imported lazily so a missing module is reported as a normal error
        # through the handler below.
        from classifier.reason.infer_reason import predict_single_reason

        result = predict_single_reason(query)

        print(f"Primary Classification: {result['category']}")
        print(f"Confidence: {result['confidence']:.4f}")

        print("\nAll Category Probabilities:")

        # Highest probability first.
        sorted_probs = sorted(
            result['probabilities'].items(),
            key=lambda item: item[1],
            reverse=True,
        )
        for category, prob in sorted_probs:
            print(f"  {category}: {prob:.4f}")

    except Exception as e:
        # CLI boundary: report the problem and signal failure to the caller.
        print(f"Error: {e}")
        print("Note: Make sure the reason classifier is trained")
        return False

    return True

def classify_batch_queries(queries_file: str, output_file: str = None) -> bool:
    """Classify every query listed in a file and print a per-category summary.

    The input format is chosen from the file name: a path ending in ``.json``
    is parsed as JSON and must hold either a list of queries or an object
    with a ``queries`` key; any other file is read as text with one query per
    line (blank lines are skipped).

    Args:
        queries_file: Path of the file holding the queries.
        output_file: Optional path. When given, the queries, the raw
            predictions and the per-category counts are written there as
            JSON.

    Returns:
        True if all queries were processed (and the output file, if
        requested, was written); False if any exception occurred. The error
        is printed rather than raised.
    """
    from collections import Counter

    try:
        # Explicit UTF-8 so results do not depend on the platform locale.
        with open(queries_file, 'r', encoding='utf-8') as f:
            if queries_file.endswith('.json'):
                data = json.load(f)
                if isinstance(data, list):
                    queries = data
                elif isinstance(data, dict):
                    queries = data.get('queries', [])
                else:
                    # Any other JSON value has no query list to read.
                    raise ValueError(
                        "JSON batch file must contain a list of queries "
                        "or an object with a 'queries' key"
                    )
            else:
                queries = [line.strip() for line in f if line.strip()]

        print(f"Processing {len(queries)} healthcare reason queries...")

        # Keep the legacy relative 'classifier' search-path entry without
        # adding a duplicate on every call.
        if 'classifier' not in sys.path:
            sys.path.append('classifier')
        from classifier.reason.infer_reason import predict_single_reason

        results = []
        for i, query in enumerate(queries, 1):
            print(f"\n{i}. Query: {query}")

            result = predict_single_reason(query)
            results.append(result)

            print(f"   Category: {result['category']} (confidence: {result['confidence']:.3f})")

        # Tally predictions once; both the saved file and the console summary
        # use the same counts.
        category_counts = Counter(result['category'] for result in results)

        if output_file:
            output_data = {
                'queries': queries,
                'predictions': results,
                'summary': {
                    'total_queries': len(queries),
                    'categories': dict(category_counts),
                },
            }

            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(output_data, f, indent=2)

            print(f"\nResults saved to {output_file}")

        print("\nSummary:")
        for category, count in sorted(category_counts.items()):
            # category_counts is non-empty only when queries is non-empty,
            # so the division below cannot divide by zero.
            percentage = (count / len(queries)) * 100
            print(f"  {category}: {count} queries ({percentage:.1f}%)")

    except Exception as e:
        # CLI boundary: report the problem and signal failure to the caller.
        print(f"Error processing batch queries: {e}")
        return False

    return True

def interactive_mode():
    """Run a prompt loop that classifies each line the user enters.

    Prints a banner with example queries, then reads lines from standard
    input. Blank lines are ignored, and ``quit`` (any letter case) ends the
    session. Ctrl-C and end-of-input (Ctrl-D or exhausted piped stdin) also
    end the session.

    Returns:
        None.
    """
    print("Healthcare Reason Classifier - Interactive Mode")
    print("=" * 50)
    print("Enter healthcare reason queries to classify (type 'quit' to exit)")
    print()
    print("Example queries to try:")
    print("  • 'I have heel pain when I walk'")
    print("  • 'My toenail is ingrown and infected'")
    print("  • 'I need routine foot care'")
    print("  • 'I sprained my ankle playing basketball'")
    print("  • 'I have plantar fasciitis'")
    print("  • 'I need a cortisone injection'")
    print()

    while True:
        try:
            user_input = input(">>> ").strip()

            if user_input.lower() == 'quit':
                break

            if user_input:
                classify_single_query(user_input)
                print()

        except (KeyboardInterrupt, EOFError):
            # EOFError must end the loop: once stdin is exhausted every
            # further input() call raises it again, so reporting it as an
            # ordinary error would spin forever.
            print()
            break
        except Exception as e:
            # Keep the session alive after an unexpected per-query failure.
            print(f"Error: {e}")
            print()

def main(argv=None):
    """Parse command-line arguments and dispatch to the selected mode.

    Modes are checked in this order: ``--interactive``, then ``--batch``,
    then a positional query. ``--output`` is only used together with
    ``--batch``. When no mode is selected the help text is printed.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default), argparse reads the arguments from ``sys.argv[1:]``.

    Returns:
        Exit code: 0 on success, 1 on failure or when no mode was selected.
    """
    parser = argparse.ArgumentParser(
        description='Healthcare Reason Classification CLI',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""

Examples:

  # Classify a single healthcare reason query

  python cli/reason_classifier_cli_new.py "I have heel pain"

  

  # Batch process queries from file

  python cli/reason_classifier_cli_new.py --batch reason_queries.txt --output results.json

  

  # Interactive mode

  python cli/reason_classifier_cli_new.py --interactive

        """
    )

    parser.add_argument('query', nargs='?', help='Healthcare reason query to classify')
    parser.add_argument('--batch', type=str, help='File containing queries to process')
    parser.add_argument('--output', type=str, help='Output file for batch results')
    parser.add_argument('--interactive', action='store_true',
                       help='Start interactive mode')

    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)

    # Interactive mode
    if args.interactive:
        interactive_mode()
        return 0

    # Batch processing
    if args.batch:
        # Report a missing file up front with a dedicated message.
        if not Path(args.batch).exists():
            print(f"Error: Batch file does not exist: {args.batch}")
            return 1

        success = classify_batch_queries(args.batch, args.output)
        return 0 if success else 1

    # Single query processing
    if args.query:
        success = classify_single_query(args.query)
        return 0 if success else 1

    # No arguments provided
    parser.print_help()
    return 1

# Script entry point: run the CLI and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())