---
license: apache-2.0
base_model:
- Qwen/Qwen3-Embedding-8B
pipeline_tag: sentence-similarity
---

This is the SitEmb-v1.5-Qwen3 model trained with additional book notes and their corresponding underlined texts.

### Transformer Usage
```python
import torch

from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
from more_itertools import chunked


residual = True
residual_factor = 0.5
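# residual=True additionally pools the hidden state at the "<|endoftext|>" separator,
# giving a chunk-only embedding; residual_factor controls how that score is mixed with
# the context-aware score (0 = context-aware only, 1 = chunk-only only).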

tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen3-Embedding-8B",
    use_fast=True,
    padding_side='left',
)

model = AutoModel.from_pretrained(
    "SituatedEmbedding/SitEmb-v1.5-Qwen3-note",
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)

def _pooling(last_hidden_state, attention_mask, pooling, normalize, input_ids=None, match_idx=None):
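    # Supported pooling modes: 'cls'/'first' (first token), 'mean'/'avg'/'average'
    # (attention-masked mean), 'last'/'eos' (last non-padding token), and 'ext'
    # (mean over tokens before the separator token given by match_idx).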
    if pooling in ['cls', 'first']:
        reps = last_hidden_state[:, 0]
    elif pooling in ['mean', 'avg', 'average']:
        masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
        reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    elif pooling in ['last', 'eos']:
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            reps = last_hidden_state[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_state.shape[0]
            reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
    elif pooling == 'ext':
        if match_idx is None:
            # default mean
            masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
            reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
        else:
            for k in range(input_ids.shape[0]):
                sep_index = input_ids[k].tolist().index(match_idx)
                attention_mask[k][sep_index:] = 0
            masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
            reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    else:
        raise ValueError(f'unknown pooling method: {pooling}')
    if normalize:
        reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps


def first_eos_token_pooling(
        last_hidden_states,
        first_eos_position,
        normalize,
):
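    # Take the hidden state at the first "<|endoftext|>" position (the chunk/context
    # separator), which serves as the chunk-only ("residual") embedding.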
    batch_size = last_hidden_states.shape[0]
    reps = last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), first_eos_position]
    if normalize:
        reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps

def get_detailed_instruct(task_description: str, query: str) -> str:
    # The original snippet calls this helper without defining it; this definition
    # follows the standard Qwen3-Embedding instruction template and is assumed here.
    return f'Instruct: {task_description}\nQuery:{query}'


def encode_query(tokenizer, model, pooling, queries, batch_size, normalize, max_length, residual):
    task = "Given a search query, retrieve relevant chunks from fictions that answer the query"
    sents = []
    for query in queries:
        sents.append(get_detailed_instruct(task, query))

    # Queries contain no "<|endoftext|>" separator, so they are encoded without the residual branch.
    return encode_passage(tokenizer, model, pooling, sents, batch_size, normalize, max_length)


@torch.no_grad()  # encoding is inference-only, so skip autograd bookkeeping
def encode_passage(tokenizer, model, pooling, passages, batch_size, normalize, max_length, residual=False):
    pas_embs = []
    pas_embs_residual = []
    total = len(passages) // batch_size + (1 if len(passages) % batch_size != 0 else 0)
    with tqdm(total=total) as pbar:
        for sent_b in chunked(passages, batch_size):
            batch_dict = tokenizer(sent_b, max_length=max_length, padding=True, truncation=True,
                                   return_tensors='pt').to(model.device)
            if residual:
                # Re-tokenize without tensors to locate, for each passage, the first
                # "<|endoftext|>" token (the tokenizer's pad token) that separates the
                # chunk from its context; its hidden state is the residual embedding.
                batch_list_dict = tokenizer(sent_b, max_length=max_length, padding=True, truncation=True)
                input_ids = batch_list_dict['input_ids']
                attention_mask = batch_list_dict['attention_mask']
                max_len = len(input_ids[0])
                # With left padding, the real input starts after (max_len - seq_len) pad tokens.
                input_starts = [max_len - sum(att) for att in attention_mask]
                eos_pos = []
                for ii, it in zip(input_ids, input_starts):
                    pos = ii.index(tokenizer.pad_token_id, it)
                    eos_pos.append(pos)
                eos_pos = torch.tensor(eos_pos).to(model.device)
            else:
                eos_pos = None
            outputs = model(**batch_dict)
            pemb_ = _pooling(outputs.last_hidden_state, batch_dict['attention_mask'], pooling, normalize)
            if residual:
                remb_ = first_eos_token_pooling(outputs.last_hidden_state, eos_pos, normalize)
                pas_embs_residual.append(remb_)
            pas_embs.append(pemb_)
            pbar.update(1)
    pas_embs = torch.cat(pas_embs, dim=0)
    if pas_embs_residual:
        pas_embs_residual = torch.cat(pas_embs_residual, dim=0)
    else:
        pas_embs_residual = None
    return pas_embs, pas_embs_residual

your_query = "Your Query"

query_hidden, _ = encode_query(
    tokenizer, model, pooling="eos", queries=[your_query],
    batch_size=8, normalize=True, max_length=8192, residual=residual,
)

# Each candidate passage is the chunk text, the "<|endoftext|>" separator, and a prompt
# that situates the chunk in its surrounding context; first_eos_token_pooling above reads
# the hidden state at that separator as the residual (chunk-only) embedding.
passage_affix = "The context in which the chunk is situated is given below. Please encode the chunk by being aware of the context. Context:\n"
your_chunk = "Your Chunk"
your_context = "Your Context"

candidate_hidden, candidate_hidden_residual = encode_passage(
    tokenizer, model, pooling="eos", passages=[f"{your_chunk}<|endoftext|>{passage_affix}{your_context}"],
    batch_size=4, normalize=True, max_length=8192, residual=residual,
)

query2candidate = query_hidden @ candidate_hidden.T    # [num_queries, num_candidates]
if candidate_hidden_residual is not None:
    query2candidate_residual = query_hidden @ candidate_hidden_residual.T
    if residual_factor == 1.:
        query2candidate = query2candidate_residual
    elif residual_factor == 0.:
        pass
    else:
        query2candidate = query2candidate * (1. - residual_factor) + query2candidate_residual * residual_factor

print(query2candidate.tolist())
```
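
The combined score interpolates the situated (chunk-plus-context) similarity with the chunk-only residual similarity; with `residual_factor = 0.5` the two are weighted equally. As a minimal sketch (not part of the official usage; the chunk and context strings below are placeholders), the same helpers can rank several candidate chunks for a single query:

```python
# Hypothetical candidates: (chunk, context) pairs to rank against the query encoded above.
chunks = [
    ("Chunk text 1", "Context text 1"),
    ("Chunk text 2", "Context text 2"),
]
passages = [f"{c}<|endoftext|>{passage_affix}{ctx}" for c, ctx in chunks]

cand, cand_res = encode_passage(
    tokenizer, model, pooling="eos", passages=passages,
    batch_size=4, normalize=True, max_length=8192, residual=residual,
)

# Interpolate the situated and residual similarities, then sort candidates for the query.
scores = (query_hidden @ cand.T).float()
if cand_res is not None:
    scores = scores * (1. - residual_factor) + (query_hidden @ cand_res.T).float() * residual_factor

ranking = torch.argsort(scores[0], descending=True).tolist()
print([(chunks[i][0], round(scores[0][i].item(), 4)) for i in ranking])
```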