File size: 4,343 Bytes
bac4585
 
 
 
 
8d33417
 
 
 
 
 
 
9412c2e
 
8d33417
 
 
 
 
 
 
 
 
 
 
6472a19
 
60b73e6
6472a19
 
 
 
8d33417
bac4585
8d33417
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bac4585
8d33417
 
 
 
 
bac4585
8d33417
 
bac4585
8d33417
 
 
 
 
 
 
 
 
bac4585
 
 
 
8d33417
 
 
 
 
 
 
 
 
 
 
 
 
d3d68bb
b4d7111
d3d68bb
ccd085d
8d33417
d3d68bb
8d33417
 
d3d68bb
 
 
bac4585
8d33417
b4d7111
d3d68bb
8d33417
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM

class Reranker:
    """Pairwise reranker built on Qwen/Qwen3-Reranker-0.6B.

    Each (query, document) pair is wrapped in a chat-style yes/no judgment
    prompt; the probability assigned to the "yes" token at the final
    position is used as the relevance score.
    """

    def __init__(self, use_float16: bool = False):
        """
        Initialize the reranker model.

        Args:
            use_float16: Load weights in float16 for lower memory use and
                faster inference (default: False).
        """
        # Use the 0.6B model instead of the 4B one to keep memory usage low.
        # Left padding so the answer token is always at the last position.
        self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-0.6B", padding_side='left')

        # Load the model, optionally with float16 precision.
        if use_float16:
            self.model = AutoModelForCausalLM.from_pretrained(
                "Qwen/Qwen3-Reranker-0.6B",
                torch_dtype=torch.float16
            ).eval()
        else:
            self.model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-0.6B").eval()

        # Cache the prefix/suffix token ids (they never change between calls).
        # BUG FIX: the prompt previously demanded the answers "da"/"ne", but
        # compute_logits scores the "yes"/"no" token ids the model was trained
        # on, so the instructed answer tokens were never the scored ones. The
        # instruction now demands "yes"/"no" to match the scoring below.
        prefix = (
            "<|im_start|>system\n"
            "Procijeni da li dati Dokument adekvatno odgovara na Upit na osnovu pravne instrukcije. "
            "Dokument treba da bude relevantan, tačan i u skladu sa važećim pravnim propisima i standardima. "
            "Odgovor mora biti striktno \"yes\" ako ispunjava uslove, ili \"no\" ako ne ispunjava.\n"
            "<|im_end|>\n"
            "<|im_start|>user\n"
        )
        suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

        self.prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
        self.suffix_tokens = self.tokenizer.encode(suffix, add_special_tokens=False)

        # Cache the yes/no answer token ids (they never change between calls).
        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")

        # Hard cap on total sequence length (prefix + pair + suffix).
        self.max_length = 2048

    def format_instruction(self, instruction, query, doc):
        """Render one (instruction, query, document) triple as the model's
        expected plain-text layout. Falls back to a generic retrieval
        instruction when none is given."""
        if instruction is None:
            instruction = 'Given a web search query, retrieve relevant passages that answer the query'
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

    def process_inputs(self, pairs):
        """Tokenize formatted (query, document) pair strings into padded,
        device-placed model inputs.

        The pair text is truncated so that, after the cached prefix and
        suffix tokens are attached, the sequence fits within max_length.
        """
        inputs = self.tokenizer(
            pairs,
            padding=False,
            truncation='longest_first',
            return_attention_mask=False,
            max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
        )

        # Attach the cached prefix and suffix token ids around each pair.
        for i, ele in enumerate(inputs['input_ids']):
            inputs['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens

        inputs = self.tokenizer.pad(
            inputs,
            padding=True,
            return_tensors="pt",
            max_length=self.max_length
        )

        # Move every tensor to the model's device.
        for key in inputs:
            inputs[key] = inputs[key].to(self.model.device)
        return inputs

    @torch.no_grad()
    def compute_logits(self, queries, documents, top_k: int = 3):
        """
        Compute reranking scores and return the top_k results.

        Args:
            queries: List of queries (usually the same query repeated).
            documents: List of documents to rerank, aligned with queries.
            top_k: Number of best results to return (default: 3).

        Returns:
            List of tuples [(score, query, document), ...] sorted by score
            in descending order; score is P("yes") in [0, 1].
        """
        task = 'Na osnovu datog upita, vrati najrelevantije rezultate koje odgovaraju upitu'
        pairs = [self.format_instruction(task, query, doc) for query, doc in zip(queries, documents)]
        inputs = self.process_inputs(pairs)

        # Logits at the last position (left padding guarantees it holds the
        # answer token), restricted to the "no"/"yes" vocabulary entries.
        batch_scores = self.model(**inputs).logits[:, -1, :]
        true_vector = batch_scores[:, self.token_true_id]
        false_vector = batch_scores[:, self.token_false_id]
        batch_scores = torch.stack([false_vector, true_vector], dim=1)
        # Softmax over the two answers; exp(log_softmax)[1] == P("yes").
        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
        scores = batch_scores[:, 1].exp().tolist()

        # Sort by score and keep the top_k.
        results = list(zip(scores, queries, documents))
        results.sort(key=lambda x: x[0], reverse=True)

        return results[:top_k]