import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


class Reranker:
    def __init__(self, use_float16: bool = False):
        """
        Initialize the reranker model.

        Args:
            use_float16: Use float16 for lower memory usage and faster inference (default: False)
        """
        # Use the 0.6B model instead of the 4B one to reduce memory usage
        self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-0.6B", padding_side='left')

        # Load the model with optional float16 precision
        if use_float16:
            self.model = AutoModelForCausalLM.from_pretrained(
                "Qwen/Qwen3-Reranker-0.6B",
                torch_dtype=torch.float16
            ).eval()
        else:
            self.model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-0.6B").eval()
        # Cache the prefix and suffix tokens (they never change)
        prefix = (
            "<|im_start|>system\n"
            "Judge whether the given Document adequately answers the Query based on the legal instruction. "
            "The Document should be relevant, accurate, and consistent with the applicable legal regulations and standards. "
            "The answer must be strictly \"yes\" if it meets the conditions, or \"no\" if it does not.\n"
            "<|im_end|>\n"
            "<|im_start|>user\n"
        )
        suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        self.prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
        self.suffix_tokens = self.tokenizer.encode(suffix, add_special_tokens=False)

        # Cache the token IDs for yes/no (they never change)
        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
        self.max_length = 2048

    def format_instruction(self, instruction, query, doc):
        if instruction is None:
            instruction = 'Given a web search query, retrieve relevant passages that answer the query'
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

    def process_inputs(self, pairs):
        """Tokenize the (query, document) input pairs for the model."""
        inputs = self.tokenizer(
            pairs,
            padding=False,
            truncation='longest_first',
            return_attention_mask=False,
            max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
        )
        # Prepend and append the cached prefix/suffix tokens
        for i, ele in enumerate(inputs['input_ids']):
            inputs['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens
        inputs = self.tokenizer.pad(
            inputs,
            padding=True,
            return_tensors="pt",
            max_length=self.max_length
        )
        for key in inputs:
            inputs[key] = inputs[key].to(self.model.device)
        return inputs

    @torch.no_grad()
    def compute_logits(self, queries, documents, top_k: int = 3):
        """
        Compute reranking scores and return the top_k results.

        Args:
            queries: List of queries (usually the same query repeated)
            documents: List of documents to rerank
            top_k: Number of best results to return (default: 3)

        Returns:
            List of tuples: [(score, query, document), ...] sorted by score
        """
        task = 'Given the query, return the most relevant results that answer the query'
        pairs = [self.format_instruction(task, query, doc) for query, doc in zip(queries, documents)]
        inputs = self.process_inputs(pairs)

        # Compute the scores: take the logits at the last position and compare
        # the probabilities of the "yes" and "no" tokens
        batch_scores = self.model(**inputs).logits[:, -1, :]
        true_vector = batch_scores[:, self.token_true_id]
        false_vector = batch_scores[:, self.token_false_id]
        batch_scores = torch.stack([false_vector, true_vector], dim=1)
        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
        scores = batch_scores[:, 1].exp().tolist()

        # Sort by score and return the top_k results
        results = list(zip(scores, queries, documents))
        results.sort(key=lambda x: x[0], reverse=True)
        return results[:top_k]
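

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original Space):
    # the query and documents below are hypothetical placeholders that only
    # illustrate the expected call pattern of compute_logits.
    reranker = Reranker(use_float16=True)

    query = "What is the notice period for terminating an employment contract?"
    documents = [
        "The notice period may not be shorter than 15 days.",
        "An employment contract must be concluded in written form.",
        "The employee is entitled to severance pay under certain conditions.",
    ]

    # compute_logits expects parallel lists, so the same query is repeated
    # once per document; it returns [(score, query, document), ...] sorted
    # by score in descending order.
    results = reranker.compute_logits([query] * len(documents), documents, top_k=2)
    for score, _, doc in results:
        print(f"{score:.4f}  {doc}")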