---
license: apache-2.0
base_model:
- Qwen/Qwen3-Embedding-8B
pipeline_tag: sentence-similarity
---
This repository hosts **SitEmb-v1.5-Qwen3**, a situated-chunk embedding model built on Qwen/Qwen3-Embedding-8B.
### Transformer Usage
```python
import torch
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
from more_itertools import chunked
# Whether to also compute the "residual" embedding (pooled at the chunk/context
# separator token) and how strongly to blend it into the final score.
residual = True
residual_factor = 0.5

# Tokenizer comes from the base model; left padding is required so that the
# 'eos'/'last' pooling path can read the final position of every sequence.
tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen3-Embedding-8B",
    use_fast=True,
    padding_side='left',
)
# Fine-tuned SitEmb checkpoint, loaded in bfloat16 entirely on GPU 0.
model = AutoModel.from_pretrained(
    "SituatedEmbedding/SitEmb-v1.5-Qwen3",
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)
def _pooling(last_hidden_state, attention_mask, pooling, normalize, input_ids=None, match_idx=None):
if pooling in ['cls', 'first']:
reps = last_hidden_state[:, 0]
elif pooling in ['mean', 'avg', 'average']:
masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
elif pooling in ['last', 'eos']:
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
if left_padding:
reps = last_hidden_state[:, -1]
else:
sequence_lengths = attention_mask.sum(dim=1) - 1
batch_size = last_hidden_state.shape[0]
reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
elif pooling == 'ext':
if match_idx is None:
# default mean
masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
else:
for k in range(input_ids.shape[0]):
sep_index = input_ids[k].tolist().index(match_idx)
attention_mask[k][sep_index:] = 0
masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
else:
raise ValueError(f'unknown pooling method: {pooling}')
if normalize:
reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
return reps
def first_eos_token_pooling(
        last_hidden_states,
        first_eos_position,
        normalize,
):
    """Gather, for each sequence, the hidden state at its first EOS position.

    Args:
        last_hidden_states: float tensor [batch, seq_len, dim].
        first_eos_position: long tensor [batch] of per-row token indices.
        normalize: if True, L2-normalize the gathered vectors.

    Returns:
        Tensor [batch, dim] of (optionally normalized) embeddings.
    """
    row_idx = torch.arange(
        last_hidden_states.shape[0], device=last_hidden_states.device
    )
    reps = last_hidden_states[row_idx, first_eos_position]
    return torch.nn.functional.normalize(reps, p=2, dim=-1) if normalize else reps
def get_detailed_instruct(task_description, query):
    """Prefix a query with its task instruction (Qwen3-Embedding convention)."""
    return f'Instruct: {task_description}\nQuery:{query}'

def encode_query(tokenizer, model, pooling, queries, batch_size, normalize, max_length, residual):
    """Embed search queries by formatting them with the retrieval instruction
    and delegating to encode_passage.

    Args:
        tokenizer/model/pooling/batch_size/normalize/max_length: passed
            through to encode_passage unchanged.
        queries: list of raw query strings.
        residual: accepted for signature symmetry with encode_passage but
            intentionally not forwarded — queries contain no chunk/context
            separator token, so a residual (first-EOS) embedding is not
            defined for them.

    Returns:
        (query_embeddings, None) — the residual slot is always None.
    """
    # Fix: get_detailed_instruct was previously undefined here, which made
    # this function raise NameError; it is now defined above.
    task = "Given a search query, retrieve relevant chunks from fictions that answer the query"
    sents = [get_detailed_instruct(task, query) for query in queries]
    return encode_passage(tokenizer, model, pooling, sents, batch_size, normalize, max_length)
def encode_passage(tokenizer, model, pooling, passages, batch_size, normalize, max_length, residual=False):
    """Batch-encode passages into pooled embeddings.

    Args:
        tokenizer: HF tokenizer; configured for left padding at load time, and
            its pad token doubles as the chunk/context separator when
            `residual` is on.
        model: HF encoder whose output exposes `last_hidden_state`.
        pooling: pooling strategy name forwarded to `_pooling`.
        passages: list of passage strings to embed.
        batch_size: passages per forward pass.
        normalize: if True, L2-normalize pooled embeddings.
        max_length: tokenizer truncation length.
        residual: if True, additionally pool the hidden state at each
            passage's first pad/EOS token (the chunk/context boundary).

    Returns:
        (embeddings, residual_embeddings); the second element is None when
        `residual` is False.
    """
    pas_embs = []
    pas_embs_residual = []
    # Number of batches, rounded up.
    total = len(passages) // batch_size + (1 if len(passages) % batch_size != 0 else 0)
    with tqdm(total=total) as pbar:
        for sent_b in chunked(passages, batch_size):
            batch_dict = tokenizer(sent_b, max_length=max_length, padding=True, truncation=True,
                                   return_tensors='pt').to(model.device)
            if residual:
                # Re-tokenize without tensors so plain Python lists can be
                # scanned for each sequence's first pad/EOS token.
                batch_list_dict = tokenizer(sent_b, max_length=max_length, padding=True, truncation=True, )
                input_ids = batch_list_dict['input_ids']
                attention_mask = batch_list_dict['attention_mask']
                max_len = len(input_ids[0])
                # With left padding, each sequence's first real token sits at
                # max_len - (number of real tokens).
                input_starts = [max_len - sum(att) for att in attention_mask]
                eos_pos = []
                for ii, it in zip(input_ids, input_starts):
                    # First pad-token occurrence at/after the sequence start is
                    # the chunk/context separator. NOTE(review): raises
                    # ValueError if a passage lacks the separator (e.g.
                    # truncated away) — inputs are assumed to contain it.
                    pos = ii.index(tokenizer.pad_token_id, it)
                    eos_pos.append(pos)
                eos_pos = torch.tensor(eos_pos).to(model.device)
            else:
                eos_pos = None
            outputs = model(**batch_dict)
            pemb_ = _pooling(outputs.last_hidden_state, batch_dict['attention_mask'], pooling, normalize)
            if residual:
                # Second embedding: the hidden state at the separator token.
                remb_ = first_eos_token_pooling(outputs.last_hidden_state, eos_pos, normalize)
                pas_embs_residual.append(remb_)
            pas_embs.append(pemb_)
            pbar.update(1)
    pas_embs = torch.cat(pas_embs, dim=0)
    if pas_embs_residual:
        pas_embs_residual = torch.cat(pas_embs_residual, dim=0)
    else:
        pas_embs_residual = None
    return pas_embs, pas_embs_residual
# --- Example usage -----------------------------------------------------------
# Fix: the original example passed the keyword `pooling_type=`, but both
# encode_query and encode_passage declare the parameter as `pooling`, so the
# calls raised TypeError. The keyword is corrected below.
your_query = "Your Query"
query_hidden, _ = encode_query(
    tokenizer, model, pooling="eos", queries=[your_query],
    batch_size=8, normalize=True, max_length=8192, residual=residual,
)

# A candidate is the chunk followed by its surrounding context, joined with the
# <|endoftext|> token so the residual pooling can locate the boundary.
passage_affix = "The context in which the chunk is situated is given below. Please encode the chunk by being aware of the context. Context:\n"
your_chunk = "Your Chunk"
your_context = "Your Context"
candidate_hidden, candidate_hidden_residual = encode_passage(
    tokenizer, model, pooling="eos", passages=[f"{your_chunk}<|endoftext|>{passage_affix}{your_context}"],
    batch_size=4, normalize=True, max_length=8192, residual=residual,
)

# Cosine-style similarity (embeddings are normalized above).
query2candidate = query_hidden @ candidate_hidden.T  # [num_queries, num_candidates]
if candidate_hidden_residual is not None:
    query2candidate_residual = query_hidden @ candidate_hidden_residual.T
    # Blend the chunk-level and residual scores by residual_factor.
    if residual_factor == 1.:
        query2candidate = query2candidate_residual
    elif residual_factor == 0.:
        pass
    else:
        query2candidate = query2candidate * (1. - residual_factor) + query2candidate_residual * residual_factor
print(query2candidate.tolist())
``` |