# File size: 10,746 Bytes
# d18c374
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
"""
Reranking Module

This module uses a cross-encoder model to rerank initial recommendations
and ensures balance between Knowledge (K) and Personality (P) assessments.
"""

import logging
from typing import Dict, List, Optional

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Configure logging when this module is imported (INFO level, timestamped
# "time - level - message" lines) and create the module-level logger that the
# rest of this file writes to.
# NOTE(review): basicConfig() at import time configures the root logger for any
# program that imports this module - confirm that is intended.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class AssessmentReranker:
    """Rerank candidate assessments with a cross-encoder and keep a K/P mix.

    Candidates are plain dicts. The keys read by this class are
    ``assessment_name``, ``category``, ``test_type`` (``'K'`` for knowledge,
    ``'P'`` for personality), ``description`` and ``score`` (the first-stage
    retrieval score). Methods annotate those dicts in place with
    ``cross_encoder_score``, ``embedding_score``, ``combined_score`` and
    ``rank``.

    The model and tokenizer are loaded lazily on the first scoring call.
    """

    # Text fed to the cross-encoder for each known test type. Types outside
    # this table are left out of the text instead of being mislabelled.
    _TYPE_LABELS = {
        'K': 'Knowledge/Skill Assessment',
        'P': 'Personality/Behavior Assessment',
    }

    def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2'):
        """Store the model name and pick the device; nothing is loaded yet.

        Args:
            model_name: Identifier passed to ``from_pretrained`` in
                :meth:`load_model`.
        """
        self.model_name = model_name
        self.model = None
        self.tokenizer = None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logger.info(f"Reranker using device: {self.device}")

    def load_model(self):
        """Load the tokenizer and cross-encoder and switch the model to eval mode.

        Raises:
            Exception: Whatever ``from_pretrained`` or the device move raises;
                the error is logged and re-raised unchanged.
        """
        try:
            logger.info(f"Loading reranking model: {self.model_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
            self.model.to(self.device)
            self.model.eval()
            logger.info("Reranking model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

    def compute_cross_encoder_score(self, query: str, assessment_text: str) -> float:
        """Score one (query, assessment text) pair with the cross-encoder.

        The model is loaded on first use; a load failure propagates to the
        caller. Failures while tokenizing or running the model are logged as
        warnings and reported as ``0.0`` so one bad candidate does not abort
        a whole reranking pass.

        Args:
            query: The search query.
            assessment_text: Text describing the assessment.

        Returns:
            The first logit of the first (only) row of the model output, as a
            Python float. No sigmoid or other normalisation is applied here.
        """
        if self.model is None:
            self.load_model()

        try:
            # Encode the pair as a single sequence, truncated to 512 tokens.
            inputs = self.tokenizer(
                query,
                assessment_text,
                return_tensors='pt',
                truncation=True,
                max_length=512,
                padding=True
            )

            # Tensors must live on the same device as the model.
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Inference only: no gradients needed.
            with torch.no_grad():
                outputs = self.model(**inputs)
                score = outputs.logits[0][0].item()

            return score

        except Exception as e:
            logger.warning(f"Error computing cross-encoder score: {e}")
            return 0.0

    @staticmethod
    def _ranking_score(assessment: Dict) -> float:
        """Return the ordering score: ``combined_score``, else ``score``, else 0."""
        return assessment.get('combined_score', assessment.get('score', 0))

    @staticmethod
    def _clean_field(value) -> str:
        """Return ``value`` as text, or ``''`` when it is missing or blank."""
        if value is None:
            return ''
        # NaN is the only float that is not equal to itself.
        if isinstance(value, float) and value != value:
            return ''
        text = str(value)
        return text if text.strip() else ''

    @staticmethod
    def _select_balanced(assessments: List[Dict], min_k: int, min_p: int, limit: int) -> List[Dict]:
        """Pick up to ``limit`` items, reserving slots for the best K and P items.

        The best-scoring ``min_k`` K items and ``min_p`` P items are taken
        first (as many as exist), the remaining slots go to the best-scoring
        items left over, and the selection is returned best-first. If the
        reserved items alone exceed ``limit`` the lowest-scoring ones are cut.
        """
        score_of = AssessmentReranker._ranking_score
        ranked = sorted(assessments, key=score_of, reverse=True)

        reserved = [a for a in ranked if a.get('test_type') == 'K'][:max(min_k, 0)]
        reserved += [a for a in ranked if a.get('test_type') == 'P'][:max(min_p, 0)]

        # Track reserved items by identity: comparing dicts by value would
        # drop distinct candidates that merely happen to be equal.
        reserved_ids = {id(a) for a in reserved}
        fillers = [a for a in ranked if id(a) not in reserved_ids]

        chosen = reserved + fillers[:max(limit - len(reserved), 0)]
        chosen.sort(key=score_of, reverse=True)
        return chosen[:max(limit, 0)]

    def create_assessment_text(self, assessment: Dict) -> str:
        """Build the text representation of an assessment for the cross-encoder.

        Fields are joined with ``' | '`` in the order name, category, test
        type label, description. Missing, ``None``, NaN or blank fields are
        skipped and non-string values are converted with ``str``. A test type
        other than ``'K'`` or ``'P'`` contributes no label.

        Args:
            assessment: Candidate dict.

        Returns:
            The joined text, or ``''`` when no usable field is present.
        """
        parts = []

        name = self._clean_field(assessment.get('assessment_name'))
        if name:
            parts.append(name)

        category = self._clean_field(assessment.get('category'))
        if category:
            parts.append(f"Category: {category}")

        test_type = assessment.get('test_type')
        if isinstance(test_type, str) and test_type in self._TYPE_LABELS:
            parts.append(self._TYPE_LABELS[test_type])

        description = self._clean_field(assessment.get('description'))
        if description:
            parts.append(description)

        return ' | '.join(parts)

    def rerank(self,
              query: str,
              candidates: List[Dict],
              top_k: int = 10,
              alpha: float = 0.5) -> List[Dict]:
        """
        Rerank candidates using cross-encoder scores

        Every candidate dict is updated in place with ``cross_encoder_score``,
        ``embedding_score`` and ``combined_score``.

        Args:
            query: Original search query
            candidates: List of candidate assessments from initial retrieval
            top_k: Number of final results to return; values below 1 give
                  an empty list
            alpha: Weight for combining embedding score and cross-encoder score
                  (0.0 = only cross-encoder, 1.0 = only embedding)

        Returns:
            New list holding the ``top_k`` best candidates, best first
        """
        if not candidates:
            return []

        logger.info(f"Reranking {len(candidates)} candidates...")

        for candidate in candidates:
            assessment_text = self.create_assessment_text(candidate)
            ce_score = self.compute_cross_encoder_score(query, assessment_text)

            # A missing or None retrieval score counts as 0.0.
            embedding_score = candidate.get('score')
            if embedding_score is None:
                embedding_score = 0.0

            # NOTE(review): ce_score is a raw logit while 'score' is presumably
            # a bounded similarity - confirm the two scales are meant to be
            # mixed directly.
            combined_score = alpha * embedding_score + (1 - alpha) * ce_score

            candidate['cross_encoder_score'] = ce_score
            candidate['embedding_score'] = embedding_score
            candidate['combined_score'] = combined_score

        reranked = sorted(candidates, key=lambda c: c['combined_score'], reverse=True)

        # Clamp so a negative top_k cannot slice from the end of the list.
        reranked = reranked[:max(top_k, 0)]

        logger.info(f"Reranking complete, returning top {len(reranked)} results")
        return reranked

    def ensure_balance(self,
                      assessments: List[Dict],
                      min_k: int = 1,
                      min_p: int = 1,
                      top_k: Optional[int] = None) -> List[Dict]:
        """
        Ensure balance between Knowledge (K) and Personality (P) assessments

        Only items already present in ``assessments`` can be selected; the
        minimums are met as far as the input allows.

        Args:
            assessments: List of assessments
            min_k: Minimum number of K assessments
            min_p: Minimum number of P assessments
            top_k: Size of the returned list. When given, slots for the best
                  K and P items are reserved before the list is cut, so the
                  minimums hold in the trimmed result. When ``None`` every
                  item is kept: the input list is returned unchanged if it
                  already meets both minimums, otherwise a score-sorted copy.

        Returns:
            Balanced list of assessments, best first when a selection is made
        """
        if not assessments:
            return []

        k_count = sum(1 for a in assessments if a.get('test_type') == 'K')
        p_count = sum(1 for a in assessments if a.get('test_type') == 'P')
        logger.info(f"Initial distribution - K: {k_count}, P: {p_count}")

        if top_k is None:
            if k_count >= min_k and p_count >= min_p:
                return assessments
            limit = len(assessments)
        else:
            limit = max(0, min(top_k, len(assessments)))

        logger.info("Adjusting to ensure minimum balance...")
        result = self._select_balanced(assessments, min_k, min_p, limit)

        k_count = sum(1 for a in result if a.get('test_type') == 'K')
        p_count = sum(1 for a in result if a.get('test_type') == 'P')
        logger.info(f"Balanced distribution - K: {k_count}, P: {p_count}")

        return result

    def rerank_and_balance(self,
                          query: str,
                          candidates: List[Dict],
                          top_k: int = 10,
                          min_k: int = 1,
                          min_p: int = 1,
                          alpha: float = 0.5) -> List[Dict]:
        """
        Rerank candidates and ensure K/P balance

        Args:
            query: Original search query
            candidates: List of candidate assessments
            top_k: Number of final results
            min_k: Minimum K assessments
            min_p: Minimum P assessments
            alpha: Weight for score combination

        Returns:
            Reranked and balanced list of at most ``top_k`` assessments, each
            annotated in place with a 1-based ``rank``
        """
        # Keep twice as many candidates as requested so the balancing step
        # has lower-ranked K/P items to draw from.
        pool = self.rerank(query, candidates, top_k=top_k * 2, alpha=alpha)

        # Balance and trim in one step: trimming a score-sorted list afterwards
        # could cut off the very items that were kept for balance.
        final_results = self.ensure_balance(pool, min_k=min_k, min_p=min_p, top_k=top_k)

        for position, assessment in enumerate(final_results, 1):
            assessment['rank'] = position

        return final_results

    def normalize_scores(self, assessments: List[Dict]) -> List[Dict]:
        """Min-max normalize scores to the 0-1 range, in place.

        The value read for each item is ``combined_score`` when present,
        otherwise ``score`` (default 0); the normalized value is written to
        ``score``. The list is returned untouched when it is empty or when
        all scores are equal.

        Args:
            assessments: List of assessments

        Returns:
            The same list object that was passed in
        """
        if not assessments:
            return assessments

        scores = [self._ranking_score(a) for a in assessments]
        low, high = min(scores), max(scores)

        # A zero range would divide by zero; leave such lists as they are.
        if high == low:
            return assessments

        span = high - low
        for assessment, raw_score in zip(assessments, scores):
            assessment['score'] = (raw_score - low) / span

        return assessments


def main():
    """Run a small demo of the reranker and return the instance that was used.

    Builds three hand-written candidates (two with test type K, one with P),
    reranks them against a fixed query and prints the rank, type and the
    three scores of every result.
    """
    reranker = AssessmentReranker()

    # Sample candidates, written as rows and expanded into dicts below.
    field_names = ('assessment_name', 'category', 'test_type', 'description', 'score')
    sample_rows = (
        ('Java Programming Assessment', 'Technical', 'K', 'Evaluates Java programming skills', 0.85),
        ('Leadership Assessment', 'Leadership', 'P', 'Evaluates leadership potential', 0.75),
        ('Python Coding Test', 'Technical', 'K', 'Assesses Python programming', 0.80),
    )
    sample_pool = [dict(zip(field_names, row)) for row in sample_rows]

    query = "Looking for a Java developer with strong leadership skills"

    print("\n=== Reranking Test ===\n")
    print(f"Query: {query}\n")

    # Ask for up to five results with at least one K and one P assessment.
    ranked = reranker.rerank_and_balance(query, sample_pool, top_k=5, min_k=1, min_p=1)

    print("Reranked Results:")
    for assessment in ranked:
        print(f"\n{assessment.get('rank', 0)}. {assessment['assessment_name']}")
        print(f"   Type: {assessment['test_type']}")
        print(f"   Embedding Score: {assessment.get('embedding_score', 0):.4f}")
        print(f"   Cross-Encoder Score: {assessment.get('cross_encoder_score', 0):.4f}")
        print(f"   Combined Score: {assessment.get('combined_score', 0):.4f}")

    return reranker


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()