import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


class Reranker:
    def __init__(self, use_float16: bool = False):
        """
        Initialize the reranker model.

        Args:
            use_float16: Use float16 for lower memory use and faster inference (default: False)
        """
        # Use the 0.6B model instead of the 4B one to keep memory usage down
        self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-0.6B", padding_side='left')
        # Load the model, optionally in float16 precision
        if use_float16:
            self.model = AutoModelForCausalLM.from_pretrained(
                "Qwen/Qwen3-Reranker-0.6B",
                torch_dtype=torch.float16
            ).eval()
        else:
            self.model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-0.6B").eval()
        # Cache the prefix and suffix tokens (they never change).
        # The prompt must ask for a literal "yes"/"no" answer, because the
        # scoring step reads the logits of exactly those two tokens.
        prefix = (
            "<|im_start|>system\n"
            "Judge whether the given Document adequately answers the Query according to the legal instruction. "
            "The Document should be relevant, accurate, and consistent with the applicable legal regulations and standards. "
            "The answer must be strictly \"yes\" if it meets the criteria, or \"no\" if it does not.\n"
            "<|im_end|>\n"
            "<|im_start|>user\n"
        )
        suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        self.prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
        self.suffix_tokens = self.tokenizer.encode(suffix, add_special_tokens=False)
        # Cache the yes/no token IDs (they never change)
        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
        self.max_length = 2048

    def format_instruction(self, instruction, query, doc):
        if instruction is None:
            instruction = 'Given a web search query, retrieve relevant passages that answer the query'
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

    def process_inputs(self, pairs):
        """Tokenize the (query, document) pairs for the model."""
        inputs = self.tokenizer(
            pairs,
            padding=False,
            truncation='longest_first',
            return_attention_mask=False,
            max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
        )
        # Prepend/append the cached prefix and suffix tokens
        for i, ele in enumerate(inputs['input_ids']):
            inputs['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens
        inputs = self.tokenizer.pad(
            inputs,
            padding=True,
            return_tensors="pt",
            max_length=self.max_length
        )
        # Move all input tensors to the model's device
        for key in inputs:
            inputs[key] = inputs[key].to(self.model.device)
        return inputs

    @torch.no_grad()  # inference only; skip building the autograd graph
    def compute_logits(self, queries, documents, top_k: int = 3):
        """
        Compute reranking scores and return the top_k results.

        Args:
            queries: List of queries (usually the same query repeated)
            documents: List of documents to rerank
            top_k: Number of best results to return (default: 3)

        Returns:
            List of tuples [(score, query, document), ...] sorted by score
        """
        task = 'Given the query, return the most relevant results that answer the query'
        pairs = [self.format_instruction(task, query, doc) for query, doc in zip(queries, documents)]
        inputs = self.process_inputs(pairs)
        # Score each pair: take the logits at the final position, compare the
        # "yes" and "no" tokens, and softmax over the two to get P("yes")
        batch_scores = self.model(**inputs).logits[:, -1, :]
        true_vector = batch_scores[:, self.token_true_id]
        false_vector = batch_scores[:, self.token_false_id]
        batch_scores = torch.stack([false_vector, true_vector], dim=1)
        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
        scores = batch_scores[:, 1].exp().tolist()
        # Sort by score and return the top_k
        results = list(zip(scores, queries, documents))
        results.sort(key=lambda x: x[0], reverse=True)
        return results[:top_k]
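

# Usage sketch (illustrative, not part of the original file): rerank a few
# candidate documents against one query. The strings below are made-up
# placeholders; compute_logits expects the query repeated once per document,
# since each (query, document) pair is scored independently.
if __name__ == "__main__":
    reranker = Reranker(use_float16=True)
    query = "What is the notice period for terminating a lease?"
    documents = [
        "The notice period for lease termination is set out in Article 12 of the act.",
        "A recipe for vegetable soup with seasonal ingredients.",
        "Tenants must give written notice at least 30 days in advance.",
    ]
    results = reranker.compute_logits([query] * len(documents), documents, top_k=2)
    for score, _, doc in results:
        print(f"{score:.4f}  {doc[:60]}")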