File size: 5,684 Bytes
2def23d
 
 
 
 
 
 
aeb17ad
2def23d
 
 
 
 
 
3def584
 
2def23d
 
 
 
 
 
 
 
 
 
 
aeb17ad
6064064
2def23d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00239f6
2def23d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3def584
 
2def23d
3def584
2def23d
 
3def584
 
 
2def23d
3def584
2def23d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
---
license: apache-2.0
base_model:
- Qwen/Qwen3-Embedding-8B
pipeline_tag: sentence-similarity
---

This is the Qwen3-based model (base model: `Qwen/Qwen3-Embedding-8B`) trained to process only the chunk itself.

### Transformers Usage
```python
import torch

from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
from more_itertools import chunked

# Scoring switches: `residual` asks `encode_passage` for an additional
# residual embedding per candidate; `residual_factor` is the weight given to
# the residual similarity when the two similarity matrices are blended.
residual = False
residual_factor = 0.5

# Left padding keeps every sequence's final token in the last column, which
# is the layout the 'last'/'eos' branch of `_pooling` checks for first.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-8B", padding_side='left', use_fast=True)

# Weights are loaded in bfloat16; the device map presumably places the whole
# model on device index 0 -- adjust it for a different hardware setup.
model = AutoModel.from_pretrained(
    "SituatedEmbedding/SitEmb-v1.5-Qwen3-chunk-only", device_map={"": 0}, torch_dtype=torch.bfloat16,
)

def _masked_mean(last_hidden_state, attention_mask):
    """Average the hidden states over the positions whose mask entry is non-zero."""
    masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


def _pooling(last_hidden_state, attention_mask, pooling, normalize, input_ids=None, match_idx=None):
    """Reduce per-token hidden states to one vector per sequence.

    Args:
        last_hidden_state: Hidden states indexed as [sequence, position, feature].
        attention_mask: Mask indexed as [sequence, position]; non-zero marks
            positions that take part in pooling. It is never modified.
        pooling: Strategy name.
            'cls' / 'first': hidden state at position 0.
            'mean' / 'avg' / 'average': mean over unmasked positions.
            'last' / 'eos': hidden state of the last unmasked position. When
                every sequence has a non-zero mask in the final column (left
                padding) this is simply the final position.
            'ext': mean over unmasked positions that come before the first
                occurrence of ``match_idx`` in each row of ``input_ids``;
                plain mean when ``match_idx`` is None.
        normalize: When truthy, L2-normalize each pooled vector.
        input_ids: Token ids indexed as [sequence, position]; required for
            'ext' pooling when ``match_idx`` is given.
        match_idx: Token id that marks where 'ext' pooling stops.

    Returns:
        One pooled vector per sequence.

    Raises:
        ValueError: If ``pooling`` is unknown, if 'ext' pooling is given a
            ``match_idx`` without ``input_ids``, or if ``match_idx`` does not
            occur in some row of ``input_ids``.
    """
    if pooling in ['cls', 'first']:
        reps = last_hidden_state[:, 0]
    elif pooling in ['mean', 'avg', 'average']:
        reps = _masked_mean(last_hidden_state, attention_mask)
    elif pooling in ['last', 'eos']:
        # A fully set final mask column means no sequence ends in padding.
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            reps = last_hidden_state[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_state.shape[0]
            reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
    elif pooling == 'ext':
        if match_idx is not None:
            if input_ids is None:
                raise ValueError("'ext' pooling with match_idx requires input_ids")
            # Work on a copy: the caller's mask is typically the same tensor
            # that was fed to the model and must not be truncated in place.
            attention_mask = attention_mask.clone()
            for k, row in enumerate(input_ids.tolist()):
                attention_mask[k, row.index(match_idx):] = 0
        reps = _masked_mean(last_hidden_state, attention_mask)
    else:
        raise ValueError(f'unknown pooling method: {pooling}')
    if normalize:
        reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps


def first_eos_token_pooling(
        last_hidden_states,
        first_eos_position,
        normalize,
):
    """Select one hidden state per sequence at a caller-supplied position.

    Args:
        last_hidden_states: Hidden states indexed as [sequence, position, ...].
        first_eos_position: For each sequence, the position index to pick.
        normalize: When truthy, L2-normalize the picked vectors along the
            last dimension.

    Returns:
        The picked (and optionally normalized) vector for every sequence.
    """
    row_index = torch.arange(last_hidden_states.size(0), device=last_hidden_states.device)
    picked = last_hidden_states[row_index, first_eos_position]
    if not normalize:
        return picked
    return torch.nn.functional.normalize(picked, p=2, dim=-1)

def get_detailed_instruct(task_description, query):
    """Prefix *query* with its task instruction.

    Uses the instruction layout published for the Qwen3-Embedding base model.
    NOTE(review): confirm this matches the prompt format this checkpoint was
    trained with.
    """
    return f'Instruct: {task_description}\nQuery:{query}'


def encode_query(tokenizer, model, pooling, queries, batch_size, normalize, max_length, residual):
    """Embed search queries after wrapping each one in the retrieval instruction.

    Args:
        tokenizer, model, pooling, batch_size, normalize, max_length:
            Forwarded unchanged to `encode_passage`.
        queries: Iterable of raw query strings.
        residual: Accepted so the signature mirrors `encode_passage`, but it
            is not forwarded; queries are always encoded with the default
            ``residual=False``.

    Returns:
        Whatever `encode_passage` returns for the instructed queries: a pair
        whose first element holds the query embeddings and whose second
        element is ``None`` because no residual embedding is requested.
    """
    task = "Given a search query, retrieve relevant chunks from fictions that answer the query"
    sents = [get_detailed_instruct(task, query) for query in queries]

    return encode_passage(tokenizer, model, pooling, sents, batch_size, normalize, max_length)


def encode_passage(tokenizer, model, pooling, passages, batch_size, normalize, max_length, residual=False):
    """Embed passages batch by batch.

    Args:
        tokenizer: Tokenizer called on each batch of strings with padding and
            truncation to ``max_length``.
        model: Model whose output exposes ``last_hidden_state``; inputs are
            moved to ``model.device``.
        pooling: Pooling strategy name handed to `_pooling`.
        passages: Sequence of strings to embed.
        batch_size: Number of passages per forward pass.
        normalize: Whether the pooled vectors are L2-normalized.
        max_length: Truncation length passed to the tokenizer.
        residual: When true, additionally return the hidden state found at
            the first ``tokenizer.pad_token_id`` occurring at or after the
            start of each sequence's unpadded content. The start offset is
            computed assuming left padding. Presumably this position is the
            first EOS token because the tokenizer reuses the same id for
            padding and EOS -- confirm against the tokenizer config.

    Returns:
        ``(pas_embs, pas_embs_residual)``: the per-batch pooled embeddings
        concatenated along dim 0, and the residual embeddings concatenated
        the same way, or ``None`` when ``residual`` is false.

    Raises:
        ValueError: With ``residual=True``, if a sequence contains no
            ``tokenizer.pad_token_id`` at or after its content start.
    """
    pas_embs = []
    pas_embs_residual = []
    total = (len(passages) + batch_size - 1) // batch_size
    # Inference only: without no_grad every stored batch output would keep
    # its autograd graph (and all activations) alive until the final cat.
    with torch.no_grad(), tqdm(total=total) as pbar:
        for sent_b in chunked(passages, batch_size):
            batch_dict = tokenizer(sent_b, max_length=max_length, padding=True, truncation=True,
                                   return_tensors='pt')
            eos_pos = None
            if residual:
                # Reuse the ids/mask of the batch just tokenized instead of
                # running the tokenizer a second time on the same strings.
                eos_pos = []
                id_rows = batch_dict['input_ids'].tolist()
                mask_rows = batch_dict['attention_mask'].tolist()
                for ids, mask in zip(id_rows, mask_rows):
                    content_start = len(ids) - sum(mask)
                    eos_pos.append(ids.index(tokenizer.pad_token_id, content_start))
                eos_pos = torch.tensor(eos_pos).to(model.device)
            batch_dict = batch_dict.to(model.device)
            outputs = model(**batch_dict)
            pemb_ = _pooling(outputs.last_hidden_state, batch_dict['attention_mask'], pooling, normalize)
            if residual:
                remb_ = first_eos_token_pooling(outputs.last_hidden_state, eos_pos, normalize)
                pas_embs_residual.append(remb_)
            pas_embs.append(pemb_)
            pbar.update(1)
    pas_embs = torch.cat(pas_embs, dim=0)
    if pas_embs_residual:
        pas_embs_residual = torch.cat(pas_embs_residual, dim=0)
    else:
        pas_embs_residual = None
    return pas_embs, pas_embs_residual

your_query = "Your Query"

# The second return value is discarded: `encode_query` does not request a
# residual embedding for queries. The keyword must be `pooling`, matching
# the parameter name in the function signatures.
query_hidden, _ = encode_query(
    tokenizer, model, pooling="eos", queries=[your_query],
    batch_size=8, normalize=True, max_length=8192, residual=residual,
)

your_chunk = "Your Chunk"

candidate_hidden, candidate_hidden_residual = encode_passage(
    tokenizer, model, pooling="eos", passages=[your_chunk],
    batch_size=4, normalize=True, max_length=8192, residual=residual,
)

query2candidate = query_hidden @ candidate_hidden.T    # [num_queries, num_candidates]
if candidate_hidden_residual is not None:
    # Blend the main and residual similarities; residual_factor is the
    # weight of the residual term (1 -> residual only, 0 -> main only).
    query2candidate_residual = query_hidden @ candidate_hidden_residual.T
    if residual_factor == 1.:
        query2candidate = query2candidate_residual
    elif residual_factor == 0.:
        pass
    else:
        query2candidate = query2candidate * (1. - residual_factor) + query2candidate_residual * residual_factor

print(query2candidate.tolist())
```