# senatus-dev / reranker.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


class Reranker:
    def __init__(self, use_float16: bool = False):
        """
        Initialize the reranker model.

        Args:
            use_float16: Use float16 for lower memory use and faster inference (default: False)
        """
        # Use the 0.6B model instead of the 4B one to keep the memory footprint small
        self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-0.6B", padding_side='left')

        # Load the model, optionally in float16 precision
        if use_float16:
            self.model = AutoModelForCausalLM.from_pretrained(
                "Qwen/Qwen3-Reranker-0.6B",
                torch_dtype=torch.float16
            ).eval()
        else:
            self.model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-0.6B").eval()

        # Cache the prefix and suffix tokens (they never change)
        # System prompt (in Serbian): judge whether the Document adequately
        # answers the Query given the legal instruction; the Document should be
        # relevant, accurate, and consistent with applicable legal regulations
        # and standards. The required answer is strictly "yes" or "no", matching
        # the yes/no token logits that compute_logits reads below.
        prefix = (
            "<|im_start|>system\n"
            "Procijeni da li dati Dokument adekvatno odgovara na Upit na osnovu pravne instrukcije. "
            "Dokument treba da bude relevantan, tačan i u skladu sa važećim pravnim propisima i standardima. "
            "Odgovor mora biti striktno \"yes\" ako ispunjava uslove, ili \"no\" ako ne ispunjava.\n"
            "<|im_end|>\n"
            "<|im_start|>user\n"
        )
        suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        self.prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
        self.suffix_tokens = self.tokenizer.encode(suffix, add_special_tokens=False)

        # Cache the yes/no token IDs (they never change)
        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
        self.max_length = 2048

    def format_instruction(self, instruction, query, doc):
        """Format an (instruction, query, document) triple into the model's expected input text."""
        if instruction is None:
            instruction = 'Given a web search query, retrieve relevant passages that answer the query'
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

    def process_inputs(self, pairs):
        """Tokenize the formatted (query, document) pairs for the model."""
        inputs = self.tokenizer(
            pairs,
            padding=False,
            truncation='longest_first',
            return_attention_mask=False,
            # Leave room for the cached prefix and suffix tokens
            max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
        )

        # Prepend the cached prefix tokens and append the suffix tokens
        for i, ele in enumerate(inputs['input_ids']):
            inputs['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens

        inputs = self.tokenizer.pad(
            inputs,
            padding=True,
            return_tensors="pt",
            max_length=self.max_length
        )
        for key in inputs:
            inputs[key] = inputs[key].to(self.model.device)
        return inputs

    @torch.no_grad()
    def compute_logits(self, queries, documents, top_k: int = 3):
        """
        Compute reranking scores and return the top_k results.

        Args:
            queries: List of queries (usually the same query repeated)
            documents: List of documents to rerank
            top_k: Number of best results to return (default: 3)

        Returns:
            List of tuples [(score, query, document), ...] sorted by score, descending
        """
        # Task instruction (in Serbian): given the query, return the most
        # relevant results that match it
        task = 'Na osnovu datog upita, vrati najrelevantnije rezultate koji odgovaraju upitu'
        pairs = [self.format_instruction(task, query, doc) for query, doc in zip(queries, documents)]
        inputs = self.process_inputs(pairs)

        # Score each pair: take the last-token logits, keep only the yes/no
        # logits, and convert them into a P("yes") relevance probability
        batch_scores = self.model(**inputs).logits[:, -1, :]
        true_vector = batch_scores[:, self.token_true_id]
        false_vector = batch_scores[:, self.token_false_id]
        batch_scores = torch.stack([false_vector, true_vector], dim=1)
        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
        scores = batch_scores[:, 1].exp().tolist()

        # Sort by score and return the top_k results
        results = list(zip(scores, queries, documents))
        results.sort(key=lambda x: x[0], reverse=True)
        return results[:top_k]
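

# --- Usage sketch (illustrative; the query and documents below are made up) ---
# A minimal example of reranking a few candidate documents against a single
# query. The query is repeated so that queries[i] pairs with documents[i].
if __name__ == "__main__":
    reranker = Reranker(use_float16=False)

    query = "Koji je zakonski rok za podnošenje žalbe?"
    documents = [
        "Žalba se podnosi u roku od 15 dana od dana dostavljanja presude.",
        "Ugovor o zakupu nepokretnosti zaključuje se u pisanoj formi.",
        "Rok za žalbu u upravnom postupku iznosi 15 dana od dana dostavljanja rešenja.",
    ]

    results = reranker.compute_logits([query] * len(documents), documents, top_k=2)
    for score, q, doc in results:
        print(f"{score:.4f}  {doc}")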