AryanPrakhar commited on
Commit
48c6574
·
verified ·
1 Parent(s): b27e979

Add inference.py

Browse files
Files changed (1) hide show
  1. concept-classifier/inference.py +179 -0
concept-classifier/inference.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Programming Paradigm Classification - Inference Script
3
+ Uses trained SVM classifier and sentence embeddings for predictions
4
+ """
5
+
6
+ import pickle
7
+ import numpy as np
8
+ from sentence_transformers import SentenceTransformer
9
+ import sys
10
+
11
+
12
+ class ProgrammingParadigmClassifier:
13
+ """Classifier for programming paradigm predictions."""
14
+
15
+ def __init__(self, classifier_path='svm_classifier.pkl',
16
+ model_name_path='sentence_model_name.txt',
17
+ confidence_threshold=0.55):
18
+ """Initialize classifier and embedding model."""
19
+ print("Loading trained SVM classifier...")
20
+ with open(classifier_path, 'rb') as f:
21
+ self.classifier = pickle.load(f)
22
+
23
+ # Load the model name that was used during training
24
+ print("Reading embedding model name from training...")
25
+ with open(model_name_path, 'r') as f:
26
+ model_name = f.read().strip()
27
+
28
+ print(f"Loading sentence embedding model: {model_name}...")
29
+ self.model = SentenceTransformer(model_name)
30
+ self.confidence_threshold = confidence_threshold
31
+ print(f"Models loaded! (Confidence threshold: {confidence_threshold})\n")
32
+
33
+ def predict(self, text):
34
+ """Predict programming paradigm for given text with uncertainty handling."""
35
+ # Generate embedding
36
+ embedding = self.model.encode([text])
37
+
38
+ # Get probabilities (handle both CalibratedClassifierCV and LinearSVC)
39
+ if hasattr(self.classifier, 'predict_proba'):
40
+ # CalibratedClassifierCV - has real probabilities
41
+ probs = self.classifier.predict_proba(embedding)[0]
42
+ else:
43
+ # LinearSVC - use decision_function and convert to probabilities
44
+ scores = self.classifier.decision_function(embedding)[0]
45
+ # Softmax to convert scores to probabilities
46
+ exp_scores = np.exp(scores - np.max(scores))
47
+ probs = exp_scores / exp_scores.sum()
48
+
49
+ prob_dict = dict(zip(self.classifier.classes_, probs))
50
+
51
+ # Get top two probabilities for margin calculation
52
+ sorted_indices = np.argsort(probs)[::-1]
53
+ sorted_probs = probs[sorted_indices]
54
+ max_prob = sorted_probs[0]
55
+ second_max = sorted_probs[1] if len(sorted_probs) > 1 else 0.0
56
+ margin = max_prob - second_max
57
+
58
+ # Get class names for top two
59
+ top_classes = self.classifier.classes_[sorted_indices]
60
+ top_class = top_classes[0]
61
+ second_class = top_classes[1] if len(top_classes) > 1 else None
62
+
63
+
64
+ if max_prob > 0.25 and second_max > 0.25 and margin < 0.08:
65
+ # Both classes are viable - return both
66
+ prediction = f"{top_class} or {second_class}"
67
+ elif max_prob < 0.30 or margin < 0.10:
68
+ prediction = "unclear"
69
+ else:
70
+ prediction = top_class
71
+
72
+ return prediction, prob_dict, max_prob
73
+
74
+ def predict_batch(self, texts):
75
+ """Predict programming paradigms for multiple texts."""
76
+ results = []
77
+ for text in texts:
78
+ prediction, probs, max_prob = self.predict(text)
79
+ results.append({
80
+ 'text': text,
81
+ 'prediction': prediction,
82
+ 'probabilities': probs,
83
+ 'confidence': max_prob
84
+ })
85
+ return results
86
+
87
+ def display_prediction(self, text, prediction, probs, max_prob):
88
+ """Display prediction results in formatted output."""
89
+ print(f"\nInput: {text[:100]}{'...' if len(text) > 100 else ''}")
90
+
91
+ # Format output for dual or single predictions
92
+ if " or " in str(prediction):
93
+ print(f"Predicted Paradigm: {prediction} (ambiguous - close call!)")
94
+ elif prediction == "unclear":
95
+ print(f"Predicted Paradigm: {prediction} (too uncertain)")
96
+ else:
97
+ print(f"Predicted Paradigm: {prediction} (confident)")
98
+
99
+ # Get top 2 classes for margin display
100
+ sorted_items = sorted(probs.items(), key=lambda x: x[1], reverse=True)
101
+ top_class, top_prob = sorted_items[0]
102
+ second_class, second_prob = sorted_items[1] if len(sorted_items) > 1 else (None, 0.0)
103
+ margin = top_prob - second_prob
104
+
105
+ print(f"Max: {top_class} ({top_prob:.3f}), 2nd: {second_class} ({second_prob:.3f}), Margin: {margin:.3f}")
106
+ print("Class Probabilities:")
107
+ for label, prob in sorted_items:
108
+ print(f" {label:12s}: {prob:7.3f}")
109
+ print("-" * 70)
110
+
111
+
112
+ def main():
113
+ """Main inference pipeline."""
114
+ print("=" * 70)
115
+ print("Programming Paradigm Classification - Inference")
116
+ print("=" * 70)
117
+
118
+ # Initialize classifier
119
+ clf = ProgrammingParadigmClassifier()
120
+
121
+ # Example texts for inference
122
+ test_texts = [
123
+ "How do I make this function pure without any side effects?",
124
+ "Why does my class hierarchy have so many levels of inheritance?",
125
+ "What's the best way to center a div in CSS?",
126
+ "This function just loops through the array and updates each element in place.",
127
+ "I'm using lambda functions to transform this list with map and filter.",
128
+ "How do I properly encapsulate private variables in my class?",
129
+ "What's the most efficient way to iterate through this data structure?",
130
+ "Can I use functional composition to chain these operations?"
131
+ ]
132
+
133
+ # Run inference on all examples
134
+ for text in test_texts:
135
+ prediction, probs, max_prob = clf.predict(text)
136
+ clf.display_prediction(text, prediction, probs, max_prob)
137
+
138
+ print("\n" + "=" * 70)
139
+ print("Inference complete!")
140
+ print("=" * 70)
141
+
142
+
143
+ def interactive_mode():
144
+ """Run classifier in interactive mode."""
145
+ print("=" * 70)
146
+ print("Programming Paradigm Classifier - Interactive Mode")
147
+ print("=" * 70)
148
+ print("Type 'quit' to exit\n")
149
+
150
+ # Initialize classifier
151
+ clf = ProgrammingParadigmClassifier()
152
+
153
+ while True:
154
+ try:
155
+ text = input("\nEnter text to classify (or 'quit' to exit): ").strip()
156
+
157
+ if text.lower() == 'quit':
158
+ print("Exiting...")
159
+ break
160
+
161
+ if not text:
162
+ print("Please enter some text.")
163
+ continue
164
+
165
+ prediction, probs, max_prob = clf.predict(text)
166
+ clf.display_prediction(text, prediction, probs, max_prob)
167
+
168
+ except KeyboardInterrupt:
169
+ print("\n\nExiting...")
170
+ break
171
+ except Exception as e:
172
+ print(f"Error: {e}")
173
+
174
+
175
+ if __name__ == "__main__":
176
+ if len(sys.argv) > 1 and sys.argv[1] == '--interactive':
177
+ interactive_mode()
178
+ else:
179
+ main()