import torch
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM


class Reranker:
    """Score (query, document) pairs with the Qwen3-Reranker-0.6B causal LM.

    The model is prompted with a chat-style template and the relevance score
    is the probability of the "yes" token versus the "no" token at the last
    position of the prompt.
    """

    # The 0.6B checkpoint is used instead of the 4B one to keep memory low.
    MODEL_NAME = "Qwen/Qwen3-Reranker-0.6B"

    def __init__(self, use_float16: bool = False):
        """Load the tokenizer and model and cache the prompt-invariant tokens.

        Args:
            use_float16: Load the weights in float16 for lower memory use and
                faster inference (default: False).
        """
        # Left padding keeps the last position of every row aligned on the
        # final prompt token, which is where the yes/no logits are read.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.MODEL_NAME, padding_side='left'
        )

        # Load the model, optionally in float16 precision.
        model_kwargs = {"torch_dtype": torch.float16} if use_float16 else {}
        self.model = AutoModelForCausalLM.from_pretrained(
            self.MODEL_NAME, **model_kwargs
        ).eval()

        # Prefix and suffix never change, so they are tokenized once here.
        # The prompt asks for "yes"/"no" because those are exactly the two
        # tokens whose logits are compared in compute_logits().
        prefix = (
            "<|im_start|>system\n"
            "Procijeni da li dati Dokument adekvatno odgovara na Upit na osnovu pravne instrukcije. "
            "Dokument treba da bude relevantan, tačan i u skladu sa važećim pravnim propisima i standardima. "
            "Odgovor mora biti striktno \"yes\" ako ispunjava uslove, ili \"no\" ako ne ispunjava.\n"
            "<|im_end|>\n"
            "<|im_start|>user\n"
        )
        # The empty <think> block is part of the template documented for
        # Qwen3-Reranker; the answer token follows immediately after it.
        suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

        self.prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
        self.suffix_tokens = self.tokenizer.encode(suffix, add_special_tokens=False)

        # Token ids of the two possible answers (constant for the tokenizer).
        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")

        # Upper bound on the total prompt length in tokens.
        self.max_length = 2048

    def format_instruction(self, instruction, query, doc):
        """Build the user turn of the prompt for a single (query, doc) pair.

        Args:
            instruction: Task description; a generic web-search instruction
                is used when None.
            query: The search query.
            doc: The candidate document.

        Returns:
            The text with the <Instruct>/<Query>/<Document> field labels.
        """
        if instruction is None:
            instruction = 'Given a web search query, retrieve relevant passages that answer the query'
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

    def process_inputs(self, pairs):
        """Tokenize formatted pairs and wrap them in the cached prefix/suffix.

        Args:
            pairs: List of strings produced by format_instruction().

        Returns:
            The padded batch (PyTorch tensors) on the model's device.
        """
        # Reserve room for prefix and suffix so the assembled sequence never
        # exceeds max_length.
        content_budget = (
            self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
        )
        inputs = self.tokenizer(
            pairs,
            padding=False,
            truncation='longest_first',
            return_attention_mask=False,
            max_length=content_budget,
        )

        # Surround every tokenized pair with the cached prefix/suffix tokens.
        inputs['input_ids'] = [
            self.prefix_tokens + ids + self.suffix_tokens
            for ids in inputs['input_ids']
        ]

        # Pad to the longest row in the batch; this also builds the
        # attention mask that was skipped above.
        inputs = self.tokenizer.pad(
            inputs,
            padding=True,
            return_tensors="pt",
            max_length=self.max_length,
        )
        return inputs.to(self.model.device)

    @torch.no_grad()
    def compute_logits(self, queries, documents, top_k: int = 3):
        """Compute reranking scores and return the top_k results.

        Args:
            queries: List of queries (usually the same query repeated once
                per document). A single string is also accepted and is
                paired with every document.
            documents: List of documents to rerank.
            top_k: Number of best results to return (default: 3).

        Returns:
            List of tuples [(score, query, document), ...] sorted by score in
            descending order. Empty when there is nothing to score. If the
            two lists differ in length, the extra items are ignored.
        """
        documents = list(documents)
        if isinstance(queries, str):
            # A bare string would otherwise be zipped character by character.
            queries = [queries] * len(documents)
        else:
            queries = list(queries)

        task = 'Na osnovu datog upita, vrati najrelevantije rezultate koje odgovaraju upitu'
        pairs = [
            self.format_instruction(task, query, doc)
            for query, doc in zip(queries, documents)
        ]
        if not pairs:
            # Nothing to score; skip the tokenizer and the forward pass.
            return []

        inputs = self.process_inputs(pairs)

        # Logits at the last position: the model's prediction for the answer.
        last_logits = self.model(**inputs).logits[:, -1, :]
        yes_no_logits = torch.stack(
            [last_logits[:, self.token_false_id], last_logits[:, self.token_true_id]],
            dim=1,
        )
        # Normalize over the two candidates only; float32 avoids precision
        # loss when the model runs in float16.
        log_probs = torch.nn.functional.log_softmax(yes_no_logits.float(), dim=1)
        scores = log_probs[:, 1].exp().tolist()

        # Sort by score (highest first) and keep the top_k entries.
        ranked = sorted(
            zip(scores, queries, documents),
            key=lambda item: item[0],
            reverse=True,
        )
        return ranked[:top_k]