---
license: apache-2.0
base_model:
- Qwen/Qwen3-Embedding-8B
pipeline_tag: sentence-similarity
---
A fine-tuned Qwen3-Embedding-8B model that encodes only the chunk text (without surrounding context).
### Transformer Usage
```python
import torch
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
from more_itertools import chunked
# Whether to also compute residual (first-EOS) passage embeddings, and how
# strongly to weight them when blending similarity scores later on.
residual = False
residual_factor = 0.5
# Tokenizer comes from the base model; left padding keeps every sequence's
# last real token at the final position, which 'eos'/'last' pooling relies on.
tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen3-Embedding-8B",
    use_fast=True,
    padding_side='left',
)
# Fine-tuned chunk-only checkpoint, loaded in bfloat16 entirely on GPU 0.
model = AutoModel.from_pretrained(
    "SituatedEmbedding/SitEmb-v1.5-Qwen3-chunk-only",
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)
def _pooling(last_hidden_state, attention_mask, pooling, normalize, input_ids=None, match_idx=None):
if pooling in ['cls', 'first']:
reps = last_hidden_state[:, 0]
elif pooling in ['mean', 'avg', 'average']:
masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
elif pooling in ['last', 'eos']:
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
if left_padding:
reps = last_hidden_state[:, -1]
else:
sequence_lengths = attention_mask.sum(dim=1) - 1
batch_size = last_hidden_state.shape[0]
reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
elif pooling == 'ext':
if match_idx is None:
# default mean
masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
else:
for k in range(input_ids.shape[0]):
sep_index = input_ids[k].tolist().index(match_idx)
attention_mask[k][sep_index:] = 0
masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
else:
raise ValueError(f'unknown pooling method: {pooling}')
if normalize:
reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
return reps
def first_eos_token_pooling(
    last_hidden_states,
    first_eos_position,
    normalize,
):
    """Gather the hidden state at each sequence's first EOS position.

    last_hidden_states: (batch, seq_len, hidden) tensor.
    first_eos_position: (batch,) tensor of positions to select, one per row.
    normalize: when True, L2-normalize the gathered vectors.
    Returns a (batch, hidden) tensor.
    """
    row_idx = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    pooled = last_hidden_states[row_idx, first_eos_position]
    return torch.nn.functional.normalize(pooled, p=2, dim=-1) if normalize else pooled
def get_detailed_instruct(task_description, query):
    # Qwen3-Embedding query-side instruction template. The original snippet
    # called this helper without defining it anywhere, raising NameError.
    return f'Instruct: {task_description}\nQuery: {query}'


def encode_query(tokenizer, model, pooling, queries, batch_size, normalize, max_length, residual):
    """Embed search queries, wrapping each in the retrieval instruction.

    NOTE(review): `residual` is accepted for API symmetry but deliberately
    not forwarded — queries never get residual embeddings, so the second
    element of the returned tuple is always None. Confirm this is intended.
    """
    task = "Given a search query, retrieve relevant chunks from fictions that answer the query"
    sents = [get_detailed_instruct(task, query) for query in queries]
    return encode_passage(tokenizer, model, pooling, sents, batch_size, normalize, max_length)
def encode_passage(tokenizer, model, pooling, passages, batch_size, normalize, max_length, residual=False):
    """Embed a list of passages in batches.

    Args:
        tokenizer: HF tokenizer (left-padded, see setup above).
        model: HF encoder model.
        pooling: pooling strategy name, forwarded to _pooling.
        passages: list of strings to embed.
        batch_size: passages per forward pass.
        normalize: L2-normalize embeddings when True.
        max_length: tokenizer truncation length.
        residual: when True, additionally pool the hidden state at each
            sequence's first pad/EOS position and return it as well.

    Returns:
        (pas_embs, pas_embs_residual): an (N, hidden) tensor, plus a matching
        residual tensor or None when residual is False.
    """
    pas_embs = []
    pas_embs_residual = []
    # Number of batches, rounded up.
    total = len(passages) // batch_size + (1 if len(passages) % batch_size != 0 else 0)
    with tqdm(total=total) as pbar:
        for sent_b in chunked(passages, batch_size):
            batch_dict = tokenizer(sent_b, max_length=max_length, padding=True, truncation=True,
                                   return_tensors='pt').to(model.device)
            if residual:
                # Re-tokenize without tensors to locate, per sequence, the
                # first pad/EOS token after the real input begins (inputs are
                # left-padded, so real tokens start at max_len - sum(mask)).
                batch_list_dict = tokenizer(sent_b, max_length=max_length, padding=True, truncation=True, )
                input_ids = batch_list_dict['input_ids']
                attention_mask = batch_list_dict['attention_mask']
                max_len = len(input_ids[0])
                input_starts = [max_len - sum(att) for att in attention_mask]
                eos_pos = []
                for ii, it in zip(input_ids, input_starts):
                    # NOTE(review): assumes a pad/EOS id exists after the input
                    # start; a fully truncated row would raise ValueError.
                    pos = ii.index(tokenizer.pad_token_id, it)
                    eos_pos.append(pos)
                eos_pos = torch.tensor(eos_pos).to(model.device)
            else:
                eos_pos = None
            # Inference only: skip autograd graph construction. The previous
            # version ran the forward pass with gradients enabled, wasting
            # large amounts of memory on an 8B model.
            with torch.no_grad():
                outputs = model(**batch_dict)
            pemb_ = _pooling(outputs.last_hidden_state, batch_dict['attention_mask'], pooling, normalize)
            if residual:
                remb_ = first_eos_token_pooling(outputs.last_hidden_state, eos_pos, normalize)
                pas_embs_residual.append(remb_)
            pas_embs.append(pemb_)
            pbar.update(1)
    pas_embs = torch.cat(pas_embs, dim=0)
    if pas_embs_residual:
        pas_embs_residual = torch.cat(pas_embs_residual, dim=0)
    else:
        pas_embs_residual = None
    return pas_embs, pas_embs_residual
# ---- Example usage ----
your_query = "Your Query"
# The keyword must be `pooling` — the functions above have no `pooling_type`
# parameter, so the original calls raised TypeError.
query_hidden, _ = encode_query(
    tokenizer, model, pooling="eos", queries=[your_query],
    batch_size=8, normalize=True, max_length=8192, residual=residual,
)
your_chunk = "Your Chunk"
candidate_hidden, candidate_hidden_residual = encode_passage(
    tokenizer, model, pooling="eos", passages=[your_chunk],
    batch_size=4, normalize=True, max_length=8192, residual=residual,
)
query2candidate = query_hidden @ candidate_hidden.T  # [num_queries, num_candidates]
if candidate_hidden_residual is not None:
    # Blend chunk-level and residual similarities by residual_factor.
    query2candidate_residual = query_hidden @ candidate_hidden_residual.T
    if residual_factor == 1.:
        query2candidate = query2candidate_residual
    elif residual_factor == 0.:
        pass
    else:
        query2candidate = query2candidate * (1. - residual_factor) + query2candidate_residual * residual_factor
print(query2candidate.tolist())
``` |