File size: 5,875 Bytes
9bb59e4
 
 
 
 
 
 
 
 
 
 
 
 
 
bcd944b
 
 
9bb59e4
 
 
 
 
 
 
 
 
 
 
ded7d4a
229d86a
9bb59e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a264ef
9bb59e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce9000f
 
9bb59e4
ce9000f
9bb59e4
 
bcd944b
 
 
 
 
9bb59e4
bcd944b
9bb59e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
---
license: apache-2.0
base_model:
- Qwen/Qwen3-Embedding-8B
pipeline_tag: sentence-similarity
---

This repository provides the model weights for SitEmb-v1.5-Qwen3, a situated-embedding model built on Qwen3-Embedding-8B.

### Transformer Usage
```python
import torch

from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
from more_itertools import chunked


# Whether to also compute the residual (chunk-only, first-EOS) embedding for
# passages, and how strongly to blend it into the final similarity score.
residual = True
residual_factor = 0.5

# Tokenizer comes from the base model; left padding keeps the last position of
# every row a real token, which the 'eos'/'last' pooling below relies on.
tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen3-Embedding-8B",
    use_fast=True,
    padding_side='left',
)

# Fine-tuned SitEmb weights, loaded in bfloat16 and pinned to GPU 0.
model = AutoModel.from_pretrained(
    "SituatedEmbedding/SitEmb-v1.5-Qwen3",
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)

def _pooling(last_hidden_state, attention_mask, pooling, normalize, input_ids=None, match_idx=None):
    if pooling in ['cls', 'first']:
        reps = last_hidden_state[:, 0]
    elif pooling in ['mean', 'avg', 'average']:
        masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
        reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    elif pooling in ['last', 'eos']:
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            reps = last_hidden_state[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_state.shape[0]
            reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
    elif pooling == 'ext':
        if match_idx is None:
            # default mean
            masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
            reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
        else:
            for k in range(input_ids.shape[0]):
                sep_index = input_ids[k].tolist().index(match_idx)
                attention_mask[k][sep_index:] = 0
            masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
            reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    else:
        raise ValueError(f'unknown pooling method: {pooling}')
    if normalize:
        reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps


def first_eos_token_pooling(
        last_hidden_states,
        first_eos_position,
        normalize,
):
    """Pick, for each batch row, the hidden state at its first EOS token.

    ``first_eos_position`` holds one index per row; the hidden vector at that
    index becomes the sequence embedding, optionally L2-normalized.
    """
    row_indices = torch.arange(
        last_hidden_states.shape[0], device=last_hidden_states.device
    )
    embeddings = last_hidden_states[row_indices, first_eos_position]
    if not normalize:
        return embeddings
    return torch.nn.functional.normalize(embeddings, p=2, dim=-1)

def get_detailed_instruct(task_description, query):
    # Standard Qwen3-Embedding instruction template. This helper was called
    # but never defined in the original snippet, so encode_query raised
    # NameError at runtime.
    return f'Instruct: {task_description}\nQuery:{query}'


def encode_query(tokenizer, model, pooling, queries, batch_size, normalize, max_length, residual):
    """Embed search queries, each prefixed with the retrieval instruction.

    Parameters mirror `encode_passage`. `residual` is accepted for interface
    symmetry but deliberately not forwarded: the residual (first-EOS)
    embedding applies only to chunk+context passages, not to queries.

    Returns:
        (embeddings, None) — the tuple produced by `encode_passage` with the
        residual path disabled.
    """
    task = "Given a search query, retrieve relevant chunks from fictions that answer the query"
    sents = [get_detailed_instruct(task, query) for query in queries]

    return encode_passage(tokenizer, model, pooling, sents, batch_size, normalize, max_length)


def encode_passage(tokenizer, model, pooling, passages, batch_size, normalize, max_length, residual=False):
    """Embed passages in batches; optionally also return residual embeddings.

    Each passage is expected to be "chunk<|endoftext|>context" when
    `residual=True` (see the usage example below): the first EOS/pad token
    inside the sequence separates the chunk from its situating context, and
    the hidden state at that position is taken as the chunk-only "residual"
    embedding.

    Returns:
        (pas_embs, pas_embs_residual): `[num_passages, dim]` tensors;
        the second is None when `residual` is False.
    """
    pas_embs = []
    pas_embs_residual = []
    # Number of batches, rounding up for the final partial batch.
    total = len(passages) // batch_size + (1 if len(passages) % batch_size != 0 else 0)
    with tqdm(total=total) as pbar:
        for sent_b in chunked(passages, batch_size):
            batch_dict = tokenizer(sent_b, max_length=max_length, padding=True, truncation=True,
                                   return_tensors='pt').to(model.device)
            if residual:
                # Tokenize a second time without return_tensors to get plain
                # Python lists, which support list.index below.
                batch_list_dict = tokenizer(sent_b, max_length=max_length, padding=True, truncation=True, )
                input_ids = batch_list_dict['input_ids']
                attention_mask = batch_list_dict['attention_mask']
                # padding=True pads all rows to the same length, so row 0's
                # length is the batch-wide padded length.
                max_len = len(input_ids[0])
                # With left padding, a row's real tokens start at
                # max_len - (number of attended tokens).
                input_starts = [max_len - sum(att) for att in attention_mask]
                eos_pos = []
                for ii, it in zip(input_ids, input_starts):
                    # First pad/EOS token at or after the content start — i.e.
                    # the <|endoftext|> separator embedded in the passage text.
                    # NOTE(review): assumes pad_token_id equals the EOS id and
                    # that the separator survives truncation; otherwise this
                    # raises ValueError — confirm for the tokenizer in use.
                    pos = ii.index(tokenizer.pad_token_id, it)
                    eos_pos.append(pos)
                eos_pos = torch.tensor(eos_pos).to(model.device)
            else:
                eos_pos = None
            outputs = model(**batch_dict)
            # Full-sequence embedding via the requested pooling method.
            pemb_ = _pooling(outputs.last_hidden_state, batch_dict['attention_mask'], pooling, normalize)
            if residual:
                # Chunk-only embedding: hidden state at the separator position.
                remb_ = first_eos_token_pooling(outputs.last_hidden_state, eos_pos, normalize)
                pas_embs_residual.append(remb_)
            pas_embs.append(pemb_)
            pbar.update(1)
    pas_embs = torch.cat(pas_embs, dim=0)
    if pas_embs_residual:
        pas_embs_residual = torch.cat(pas_embs_residual, dim=0)
    else:
        pas_embs_residual = None
    return pas_embs, pas_embs_residual

your_query = "Your Query"

# NOTE: the keyword must be `pooling` — the original snippet passed
# `pooling_type=`, which neither encode_query nor encode_passage accepts,
# raising TypeError before any encoding happened.
query_hidden, _ = encode_query(
    tokenizer, model, pooling="eos", queries=[your_query],
    batch_size=8, normalize=True, max_length=8192, residual=residual,
)

passage_affix = "The context in which the chunk is situated is given below. Please encode the chunk by being aware of the context. Context:\n"
your_chunk = "Your Chunk"
your_context = "Your Context"

# The chunk and its situating context are joined by <|endoftext|>;
# encode_passage locates that separator to extract the residual
# (chunk-only) embedding when residual=True.
candidate_hidden, candidate_hidden_residual = encode_passage(
    tokenizer, model, pooling="eos", passages=[f"{your_chunk}<|endoftext|>{passage_affix}{your_context}"],
    batch_size=4, normalize=True, max_length=8192, residual=residual,
)

# Query-to-candidate similarity matrix: [num_queries, num_candidates].
query2candidate = query_hidden @ candidate_hidden.T
if candidate_hidden_residual is not None:
    query2candidate_residual = query_hidden @ candidate_hidden_residual.T
    # Blend the full-passage and residual similarities by residual_factor
    # (0 = full-passage only, 1 = residual only).
    if residual_factor == 1.:
        query2candidate = query2candidate_residual
    elif residual_factor == 0.:
        pass
    else:
        query2candidate = query2candidate * (1. - residual_factor) + query2candidate_residual * residual_factor

print(query2candidate.tolist())
```