maxoul commited on
Commit
f440f7a
·
verified ·
1 Parent(s): eeb94eb

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +151 -0
utils.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+ from typing import Any
5
+ from transformers import AutoTokenizer
6
+
7
+
8
def splade_max(features, attention_mask):
    """
    SPLADE max pooling: log-saturated ReLU activations, max-pooled over the
    sequence dimension with padded positions masked out.

    Args:
        features: activation tensor — assumed (batch, seq_len, vocab_size),
            per-token logits over the vocabulary. TODO confirm against caller.
        attention_mask: (batch, seq_len) mask; zero entries null out padding
            before the max is taken.

    Returns:
        Tuple ``(values, indices)``: per-vocab-dimension maximum activation and
        the sequence position it came from, each of shape (batch, vocab_size).
    """
    # torch.log1p(x) is numerically more accurate than torch.log(1 + x) for
    # small activations, and the functional torch.relu avoids constructing a
    # fresh nn.ReLU module on every call.
    activations = torch.log1p(torch.relu(features))
    values, indices = torch.max(activations * attention_mask.unsqueeze(-1), dim=1)
    return values, indices
17
+
18
+
19
def encode(
    self,
    sentences: list[str],
    max_length: int = 1024,
    prompt_type: str = "document",
    return_dict: bool = False,
    print_dict: bool = False,
    batch_size: int = 8,
    top_k_q: int = -1,
    top_k_d: int = -1,
    **kwargs: Any,
) -> "np.ndarray | list[dict[str, float]]":
    """
    Encode sentences into SPLADE sparse vectors, batch by batch.

    Args:
        self: model wrapper providing ``create_batch_dict``, ``model.device``
            and a forward ``__call__`` returning the sparse reps at index 0.
        sentences: texts to encode.
        max_length: tokenizer truncation length passed to ``create_batch_dict``.
        prompt_type: "query" or "document"; selects which top-k cap applies.
        return_dict: if True, return human-readable bag-of-words dicts
            instead of the dense matrix.
        print_dict: if True (and ``return_dict``), pretty-print the bags.
        batch_size: number of sentences per forward pass.
        top_k_q / top_k_d: keep only the k largest activations per vector
            (<= 0 disables the filtering).
        **kwargs: accepted for interface compatibility (e.g. MTEB extras)
            and ignored.

    Returns:
        (n_sentences, vocab_size) float32 numpy array, or a list of
        ``{token: weight}`` dicts when ``return_dict`` is True.
        (The original annotation claimed ``np.ndarray`` only — fixed.)
    """
    all_embeddings = []
    for start in range(0, len(sentences), batch_size):
        batch_texts = sentences[start : start + batch_size]
        batch_dict = self.create_batch_dict(batch_texts, max_length)
        batch_dict = {
            key: value.to(self.model.device) for key, value in batch_dict.items()
        }
        with torch.no_grad():
            splade_reps = self(**batch_dict)[0]
            # Optional sparsification: keep only the k strongest activations.
            if prompt_type == "query" and top_k_q > 0:
                splade_reps = top_k(splade_reps, top_k_q)
            if prompt_type == "document" and top_k_d > 0:
                splade_reps = top_k(splade_reps, top_k_d)
            all_embeddings.append(splade_reps.cpu().float().numpy())
    # Concatenate once instead of separately in each return branch.
    embeddings = np.concatenate(all_embeddings, axis=0)
    if return_dict:
        bags = bow_dict(self, embeddings)
        if print_dict:
            print_bow_bars(sentences, bags)
        return bags
    return embeddings
52
+
53
+
54
def bow_dict(self, embeddings):
    """
    Turn dense SPLADE vectors into readable bag-of-words dicts.

    For each row, keeps the non-zero entries, maps vocabulary indices to
    tokens via ``self.reverse_voc`` and orders the result by descending
    weight.

    Returns:
        List of ``{token: weight}`` dicts, one per input vector.
    """
    bags = []
    for vec in embeddings:
        nonzero = np.nonzero(vec)[0]
        pairs = sorted(
            zip(nonzero.tolist(), vec[nonzero].tolist()),
            key=lambda pair: pair[1],
            reverse=True,
        )
        bags.append({self.reverse_voc[tok]: float(w) for tok, w in pairs})
    return bags
66
+
67
+
68
def print_bow_bars(sentences, bow_list, width=20):
    """
    Pretty-print each sentence's bag-of-words as an ASCII bar chart.

    Args:
        sentences: input texts (paired positionally with ``bow_list``).
        bow_list: list of ``{token: weight}`` dicts, e.g. from ``bow_dict``.
        width: bar length (in characters) of the heaviest term.
    """
    ascii_header("TOP ACTIVATED WORDS")
    for sent, bow in zip(sentences, bow_list):
        print(f"* INPUT: {sent}\n")
        if not bow:
            # All-zero vector yields an empty bag; max() on it would raise.
            print("(no activated terms)\n")
            continue
        max_w = max(bow.values())
        for term, weight in sorted(bow.items(), key=lambda kv: kv[1], reverse=True):
            # Guard against a zero maximum to avoid ZeroDivisionError.
            bar = "█" * int(weight / max_w * width) if max_w > 0 else ""
            print(f"{term[:25]:25} | {bar} {weight:.2f}")
        print("\n")
77
+
78
+
79
def ascii_header(title, width=70):
    """
    Print *title* inside a boxed ASCII banner, followed by a blank gap.

    Args:
        title: text to center inside the box (padded with one space each side).
        width: total width of the box, borders included.
    """
    inner = width - 2
    border = "+" + "-" * inner + "+"
    print(border)
    print("|" + f" {title} ".center(inner) + "|")
    print(border)
    print("\n")
85
+
86
+
87
def similarity(self, a, b) -> torch.Tensor:
    """
    Dot-product similarity between two embedding batches (MTEB eval
    requires this method).

    Accepts tensors or anything ``torch.tensor`` can convert (lists,
    numpy arrays); 1-D inputs are promoted to a single-row matrix.

    Returns:
        (rows_a, rows_b) score matrix ``a @ b.T``.
    """
    def _as_matrix(x):
        # Convert to a tensor if needed, then ensure a 2-D (batch, dim) shape.
        t = x if isinstance(x, torch.Tensor) else torch.tensor(x)
        return t.unsqueeze(0) if len(t.shape) == 1 else t

    return _as_matrix(a) @ _as_matrix(b).transpose(0, 1)
104
+
105
+
106
def prepare_tokenizer(tokenizer_name: str, padding_side="right"):
    """
    Load a Hugging Face tokenizer and make it padding-ready.

    Ensures a pad token exists, preferring (in order) the BOS token, any
    pre-existing pad token, then the EOS token — NOTE(review): BOS-first
    preference looks deliberate for this model family; confirm upstream.

    Args:
        tokenizer_name: model id or local path for ``AutoTokenizer``.
        padding_side: "right" (default) or "left".

    Returns:
        The configured tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    pad = tokenizer.bos_token or tokenizer.pad_token or tokenizer.eos_token
    tokenizer.pad_token = pad
    tokenizer.padding_side = padding_side
    return tokenizer
116
+
117
+
118
def get_decoder_model(
    model_name_or_path: str, attn_implementation: str, bidirectional: bool, base_cfg
):
    """
    Load the bidirectional Qwen3 decoder backbone in bfloat16.

    Args:
        model_name_or_path: checkpoint id or path for ``from_pretrained``.
        attn_implementation: must be "flash_attention_2" (only supported).
        bidirectional: must be True — the checkpoint was trained with
            bi-directional attention.
        base_cfg: the pretrained config of the underlying model.

    Returns:
        A ``Qwen3BidirForCausalLM`` instance.

    Raises:
        ValueError: if ``bidirectional`` is not True or
            ``attn_implementation`` is not "flash_attention_2".
    """
    print("WARNING: bidirectional only tested for transformer 4.51.2")
    # Explicit raises instead of `assert`: asserts are stripped under
    # `python -O`, which would silently load an unsupported configuration.
    if bidirectional is not True:
        raise ValueError("the model has been trained with bi-directional attention!")
    if attn_implementation != "flash_attention_2":
        raise ValueError(
            f"bidir models only support flash_attention_2 for now, not {attn_implementation}!"
        )
    # Imported lazily so merely importing this module doesn't require the
    # (flash-attention-dependent) modeling code.
    from .modeling_qwen3_bidir import Qwen3BidirForCausalLM

    return Qwen3BidirForCausalLM.from_pretrained(
        model_name_or_path,
        config=base_cfg,
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_implementation,
    )
139
+
140
+
141
def top_k(x: torch.Tensor, k: int) -> torch.Tensor:
    """
    Zero out all but the top-k values along the last dimension of ``x``.

    Args:
        x: input tensor; any shape, filtering applies to the last dim.
        k: number of entries to keep per slice.

    Returns:
        A new tensor of the same shape with non-top-k entries set to zero.
        If ``k >= x.size(-1)`` the input is returned unchanged (as a copy) —
        the original would have raised inside ``topk``.
    """
    # Generalization: topk() errors when k exceeds the dimension size, but
    # "keep everything" is the natural meaning there.
    if k >= x.size(-1):
        return x.clone()
    _, topk_indices = x.topk(k, dim=-1)
    # Boolean mask, True only at the top-k positions of each slice.
    mask = torch.zeros_like(x, dtype=torch.bool)
    mask.scatter_(-1, topk_indices, True)
    # Multiplying by the mask zeroes the rest without an in-place write to x.
    return x * mask