---
license: apache-2.0
base_model:
- Qwen/Qwen3-Embedding-8B
pipeline_tag: sentence-similarity
---
This is the SitEmb-v1.5-Qwen3 model trained with additional book notes and their corresponding underlined texts.
### Transformer Usage
```python
import torch
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
from more_itertools import chunked
residual = True
residual_factor = 0.5
# Tokenizer comes from the base Qwen3 embedding model. Left padding is
# required so that 'eos'/last-token pooling can read the final position
# of every row in a batch (see _pooling's left_padding fast path).
tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen3-Embedding-8B",
    use_fast=True,
    padding_side='left',
)
# Fine-tuned SitEmb checkpoint, loaded in bfloat16 and placed on GPU 0.
model = AutoModel.from_pretrained(
    "SituatedEmbedding/SitEmb-v1.5-Qwen3-note",
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)
def _pooling(last_hidden_state, attention_mask, pooling, normalize, input_ids=None, match_idx=None):
if pooling in ['cls', 'first']:
reps = last_hidden_state[:, 0]
elif pooling in ['mean', 'avg', 'average']:
masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
elif pooling in ['last', 'eos']:
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
if left_padding:
reps = last_hidden_state[:, -1]
else:
sequence_lengths = attention_mask.sum(dim=1) - 1
batch_size = last_hidden_state.shape[0]
reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
elif pooling == 'ext':
if match_idx is None:
# default mean
masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
else:
for k in range(input_ids.shape[0]):
sep_index = input_ids[k].tolist().index(match_idx)
attention_mask[k][sep_index:] = 0
masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
else:
raise ValueError(f'unknown pooling method: {pooling}')
if normalize:
reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
return reps
def first_eos_token_pooling(
    last_hidden_states,
    first_eos_position,
    normalize,
):
    """Select, for each batch row, the hidden state at its first EOS token.

    `first_eos_position` holds one index per row; the hidden state at that
    index is returned, optionally L2-normalized.
    """
    row_idx = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    embeddings = last_hidden_states[row_idx, first_eos_position]
    return torch.nn.functional.normalize(embeddings, p=2, dim=-1) if normalize else embeddings
def get_detailed_instruct(task_description, query):
    """Prefix a query with its task instruction (Qwen-embedding convention).

    This helper was called but never defined in the original example,
    which made encode_query raise NameError at runtime.
    """
    return f'Instruct: {task_description}\nQuery:{query}'

def encode_query(tokenizer, model, pooling, queries, batch_size, normalize, max_length, residual):
    """Embed search queries by wrapping each in the retrieval instruction.

    Args:
        tokenizer, model: as produced by the setup above.
        pooling: pooling mode understood by _pooling (e.g. 'eos').
        queries: list of raw query strings.
        batch_size, normalize, max_length: forwarded to encode_passage.
        residual: accepted for signature symmetry with encode_passage but
            intentionally NOT forwarded — queries carry no
            "<|endoftext|>context" suffix, so residual pooling does not apply.

    Returns:
        (embeddings, None) — the tuple shape of encode_passage.
    """
    task = "Given a search query, retrieve relevant chunks from fictions that answer the query"
    sents = [get_detailed_instruct(task, query) for query in queries]
    return encode_passage(tokenizer, model, pooling, sents, batch_size, normalize, max_length)
def encode_passage(tokenizer, model, pooling, passages, batch_size, normalize, max_length, residual=False):
    """Embed passages in batches; optionally also return "residual" embeddings.

    Args:
        tokenizer: HF tokenizer (configured with left padding above).
        model: HF model exposing `last_hidden_state`.
        pooling: pooling mode for _pooling (e.g. 'eos').
        passages: list of strings; for the residual path each is expected
            to look like "chunk<|endoftext|>context", so the chunk ends at
            the first EOS token.
        batch_size: sequences per forward pass.
        normalize: if True, L2-normalize both embedding sets.
        max_length: tokenizer truncation length.
        residual: if True, additionally pool the hidden state at each
            sequence's first EOS token (the chunk/context boundary).

    Returns:
        (pas_embs, pas_embs_residual): an [N, hidden] tensor, plus a
        matching tensor when residual=True, else None.
    """
    pas_embs = []
    pas_embs_residual = []
    # ceil(len(passages) / batch_size) batches, for the progress bar.
    total = len(passages) // batch_size + (1 if len(passages) % batch_size != 0 else 0)
    with tqdm(total=total) as pbar:
        for sent_b in chunked(passages, batch_size):
            batch_dict = tokenizer(sent_b, max_length=max_length, padding=True, truncation=True,
                return_tensors='pt').to(model.device)
            if residual:
                # Tokenize again without tensors to get plain Python lists,
                # used to locate each row's first pad-id occurrence after
                # the left padding.
                batch_list_dict = tokenizer(sent_b, max_length=max_length, padding=True, truncation=True, )
                input_ids = batch_list_dict['input_ids']
                attention_mask = batch_list_dict['attention_mask']
                max_len = len(input_ids[0])
                # With left padding, the first real token of each row sits
                # at max_len - (number of attended tokens).
                input_starts = [max_len - sum(att) for att in attention_mask]
                eos_pos = []
                for ii, it in zip(input_ids, input_starts):
                    # First pad id at/after the real tokens begin.
                    # NOTE(review): this assumes pad_token_id equals the
                    # <|endoftext|> separator id used in the passage format
                    # (true for Qwen3) — confirm for other tokenizers; it
                    # raises ValueError if no such token exists (e.g. the
                    # separator was truncated away).
                    pos = ii.index(tokenizer.pad_token_id, it)
                    eos_pos.append(pos)
                eos_pos = torch.tensor(eos_pos).to(model.device)
            else:
                eos_pos = None
            outputs = model(**batch_dict)
            # Full-sequence embedding (chunk + context).
            pemb_ = _pooling(outputs.last_hidden_state, batch_dict['attention_mask'], pooling, normalize)
            if residual:
                # Chunk-only embedding: hidden state at the first EOS token.
                remb_ = first_eos_token_pooling(outputs.last_hidden_state, eos_pos, normalize)
                pas_embs_residual.append(remb_)
            pas_embs.append(pemb_)
            pbar.update(1)
    pas_embs = torch.cat(pas_embs, dim=0)
    if pas_embs_residual:
        pas_embs_residual = torch.cat(pas_embs_residual, dim=0)
    else:
        pas_embs_residual = None
    return pas_embs, pas_embs_residual
# --- Example usage ----------------------------------------------------------
your_query = "Your Query"
# The encode_* functions name this parameter `pooling`; the original
# example passed `pooling_type=`, which raises a TypeError.
query_hidden, _ = encode_query(
    tokenizer, model, pooling="eos", queries=[your_query],
    batch_size=8, normalize=True, max_length=8192, residual=residual,
)
passage_affix = "The context in which the chunk is situated is given below. Please encode the chunk by being aware of the context. Context:\n"
your_chunk = "Your Chunk"
your_context = "Your Context"
# Passages are formatted "chunk<|endoftext|>context" so the residual path
# can pool at the chunk/context boundary.
candidate_hidden, candidate_hidden_residual = encode_passage(
    tokenizer, model, pooling="eos", passages=[f"{your_chunk}<|endoftext|>{passage_affix}{your_context}"],
    batch_size=4, normalize=True, max_length=8192, residual=residual,
)
query2candidate = query_hidden @ candidate_hidden.T  # [num_queries, num_candidates]
if candidate_hidden_residual is not None:
    query2candidate_residual = query_hidden @ candidate_hidden_residual.T
    # Blend full-sequence and chunk-only (residual) similarities.
    if residual_factor == 1.:
        query2candidate = query2candidate_residual
    elif residual_factor == 0.:
        pass
    else:
        query2candidate = query2candidate * (1. - residual_factor) + query2candidate_residual * residual_factor
print(query2candidate.tolist())
```