File size: 2,672 Bytes
90c099b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""
Reranker Pool for multiprocessing-safe reranker sharing

Works around FlagReranker not being picklable in multi-process/multi-threaded environments.
"""
import os
import threading
from typing import Optional, Dict
from pathlib import Path

# Module-level cache of loaded rerankers, keyed by the
# "<model_path>_<use_fp16>" string built in get_reranker().
# Note: lookups may be done without the lock, but every modification of the
# dict in this module (insert, clear) is performed while holding _reranker_lock.
_reranker_pool: Dict[str, object] = {}
_reranker_lock = threading.Lock()


def get_reranker(model_path: str, use_fp16: bool = True):
    """
    Get or create a cached reranker (thread-safe).

    The cache is a module-level dict, so an instance is reused by every
    caller in the current process; it is not shared with other processes.

    Performance optimization:
    - Load and cache on first call
    - Return cached instance on subsequent calls (no lock taken)
    - Re-check the cache after acquiring the lock, so concurrent first calls
      load the model only once

    Args:
        model_path: Reranker model path
        use_fp16: whether to use FP16

    Returns:
        FlagReranker instance

    Raises:
        ImportError: if the FlagEmbedding package cannot be imported. An
            ImportError raised while constructing the model itself is
            propagated unchanged.
    """
    # create unique key
    key = f"{model_path}_{use_fp16}"

    # fast path (no lock): a single .get() instead of "in" + "[]" so that a
    # concurrent clear_reranker_pool() between the two steps cannot raise
    # KeyError here
    cached = _reranker_pool.get(key)
    if cached is not None:
        return cached

    # slow path: needs loading (needs lock)
    with _reranker_lock:
        # double check: another thread may have loaded while we were waiting
        cached = _reranker_pool.get(key)
        if cached is not None:
            return cached

        # lazy import, avoid importing when module is loaded; only the import
        # is guarded so failures inside the constructor are not misreported
        try:
            from FlagEmbedding import FlagReranker
        except ImportError as err:
            raise ImportError(
                "FlagEmbedding not installed. Install it with: pip install FlagEmbedding"
            ) from err

        # temporarily lower transformers verbosity while the model loads;
        # None (not '') marks "was unset" so an existing empty value survives
        original_verbosity = os.environ.get('TRANSFORMERS_VERBOSITY')
        os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
        try:
            reranker = FlagReranker(model_path, use_fp16=use_fp16)
        finally:
            # restore the environment exactly as it was before the load
            if original_verbosity is None:
                os.environ.pop('TRANSFORMERS_VERBOSITY', None)
            else:
                os.environ['TRANSFORMERS_VERBOSITY'] = original_verbosity

        _reranker_pool[key] = reranker
        return reranker


def clear_reranker_pool():
    """Empty the reranker pool while holding the pool lock (mainly for testing)."""
    # The dict is cleared in place, never rebound, so no ``global`` is needed.
    with _reranker_lock:
        _reranker_pool.clear()