senatus123 commited on
Commit
8d33417
·
verified ·
1 Parent(s): 72b1c14

Upload reranker.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. reranker.py +71 -28
reranker.py CHANGED
@@ -3,18 +3,26 @@ from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
3
 
4
  class Reranker:
5
 
6
- def __init__(self):
 
 
 
 
 
 
7
  # Koristi 0.6B model umesto 4B zbog manje memorije
8
  self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-0.6B", padding_side='left')
9
- self.model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-0.6B").eval()
10
-
11
- def format_instruction(self,instruction, query, doc):
12
- if instruction is None:
13
- instruction = 'Given a web search query, retrieve relevant passages that answer the query'
14
- output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(instruction=instruction,query=query, doc=doc)
15
- return output
16
-
17
- def process_inputs(self,pairs):
 
 
18
  prefix = (
19
  "<|im_start|>system\n"
20
  "Procijeni da li dati Dokument adekvatno odgovara na Upit na osnovu pravne instrukcije. "
@@ -22,40 +30,75 @@ class Reranker:
22
  "Odgovor mora biti striktno \"da\" ako ispunjava uslove, ili \"ne\" ako ne ispunjava.\n"
23
  "<|im_end|>\n"
24
  "<|im_start|>user\n"
25
- )
26
  suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
27
- prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
28
- suffix_tokens = self.tokenizer.encode(suffix, add_special_tokens=False)
29
- max_length = 2048
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  inputs = self.tokenizer(
31
- pairs, padding=False, truncation='longest_first',
32
- return_attention_mask=False, max_length=max_length - len(prefix_tokens) - len(suffix_tokens)
 
 
 
33
  )
 
 
34
  for i, ele in enumerate(inputs['input_ids']):
35
- inputs['input_ids'][i] = prefix_tokens + ele + suffix_tokens
36
- inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=max_length)
 
 
 
 
 
 
 
37
  for key in inputs:
38
  inputs[key] = inputs[key].to(self.model.device)
39
  return inputs
40
 
41
- @torch.no_grad
42
- def compute_logits(self,queries,documents):
43
- token_false_id = self.tokenizer.convert_tokens_to_ids("no")
44
- token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
45
-
 
 
 
 
 
 
 
 
46
  task = 'Na osnovu datog upita, vrati najrelevantije rezultate koje odgovaraju upitu'
47
  pairs = [self.format_instruction(task, query, doc) for query, doc in zip(queries, documents)]
48
  inputs = self.process_inputs(pairs)
49
 
 
50
  batch_scores = self.model(**inputs).logits[:, -1, :]
51
- true_vector = batch_scores[:, token_true_id]
52
- false_vector = batch_scores[:, token_false_id]
53
  batch_scores = torch.stack([false_vector, true_vector], dim=1)
54
  batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
55
  scores = batch_scores[:, 1].exp().tolist()
56
 
 
57
  results = list(zip(scores, queries, documents))
58
  results.sort(key=lambda x: x[0], reverse=True)
59
- top_10 = results[:10]
60
-
61
- return top_10
 
3
 
4
  class Reranker:
5
 
6
+ def __init__(self, use_float16: bool = False):
7
+ """
8
+ Inicijalizacija reranker modela
9
+
10
+ Args:
11
+ use_float16: Koristi float16 za manju memoriju i brži inference (default: False)
12
+ """
13
  # Koristi 0.6B model umesto 4B zbog manje memorije
14
  self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-0.6B", padding_side='left')
15
+
16
+ # Učitaj model sa opcionalnom float16 preciznosti
17
+ if use_float16:
18
+ self.model = AutoModelForCausalLM.from_pretrained(
19
+ "Qwen/Qwen3-Reranker-0.6B",
20
+ torch_dtype=torch.float16
21
+ ).eval()
22
+ else:
23
+ self.model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-0.6B").eval()
24
+
25
+ # Cache prefix i suffix tokene (ne mijenjaju se)
26
  prefix = (
27
  "<|im_start|>system\n"
28
  "Procijeni da li dati Dokument adekvatno odgovara na Upit na osnovu pravne instrukcije. "
 
30
  "Odgovor mora biti striktno \"da\" ako ispunjava uslove, ili \"ne\" ako ne ispunjava.\n"
31
  "<|im_end|>\n"
32
  "<|im_start|>user\n"
33
+ )
34
  suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
35
+
36
+ self.prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
37
+ self.suffix_tokens = self.tokenizer.encode(suffix, add_special_tokens=False)
38
+
39
+ # Cache token IDs za yes/no (ne mijenjaju se)
40
+ self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
41
+ self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
42
+
43
+ self.max_length = 2048
44
+
45
+ def format_instruction(self, instruction, query, doc):
46
+ if instruction is None:
47
+ instruction = 'Given a web search query, retrieve relevant passages that answer the query'
48
+ return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"
49
+
50
+ def process_inputs(self, pairs):
51
+ """Procesira input parove (query, document) za model"""
52
  inputs = self.tokenizer(
53
+ pairs,
54
+ padding=False,
55
+ truncation='longest_first',
56
+ return_attention_mask=False,
57
+ max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
58
  )
59
+
60
+ # Dodaj cache-ovane prefix i suffix tokene
61
  for i, ele in enumerate(inputs['input_ids']):
62
+ inputs['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens
63
+
64
+ inputs = self.tokenizer.pad(
65
+ inputs,
66
+ padding=True,
67
+ return_tensors="pt",
68
+ max_length=self.max_length
69
+ )
70
+
71
  for key in inputs:
72
  inputs[key] = inputs[key].to(self.model.device)
73
  return inputs
74
 
75
+ @torch.no_grad()
76
+ def compute_logits(self, queries, documents, top_k: int = 3):
77
+ """
78
+ Izračunaj reranking skorove i vrati top_k rezultata
79
+
80
+ Args:
81
+ queries: Lista query-ja (obično isti query ponovljen)
82
+ documents: Lista dokumenata za reranking
83
+ top_k: Broj najboljih rezultata (default: 3)
84
+
85
+ Returns:
86
+ Lista tuple-ova: [(score, query, document), ...] sortirano po skoru
87
+ """
88
  task = 'Na osnovu datog upita, vrati najrelevantije rezultate koje odgovaraju upitu'
89
  pairs = [self.format_instruction(task, query, doc) for query, doc in zip(queries, documents)]
90
  inputs = self.process_inputs(pairs)
91
 
92
+ # Izračunaj skorove
93
  batch_scores = self.model(**inputs).logits[:, -1, :]
94
+ true_vector = batch_scores[:, self.token_true_id]
95
+ false_vector = batch_scores[:, self.token_false_id]
96
  batch_scores = torch.stack([false_vector, true_vector], dim=1)
97
  batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
98
  scores = batch_scores[:, 1].exp().tolist()
99
 
100
+ # Sortiraj i vrati top_k
101
  results = list(zip(scores, queries, documents))
102
  results.sort(key=lambda x: x[0], reverse=True)
103
+
104
+ return results[:top_k]