raxtemur commited on
Commit
204e3e7
·
verified ·
1 Parent(s): 5d9a148

Initial upload (weights + code + README)

Browse files
README.md ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ tags:
5
+ - sonar-llm
6
+ - sonar
7
+ - llama
8
+ - text-generation
9
+ - embeddings
10
+ license: apache-2.0
11
+ library_name: transformers
12
+ pipeline_tag: text-generation
13
+ ---
14
+
15
+ # SONAR-LLM (39M)
16
+
17
+ We present SONAR-LLM, a decoder-only transformer that "thinks" in the continuous SONAR sentence-embedding space, yet is supervised through token-level cross-entropy propagated via the frozen SONAR decoder. This hybrid objective retains the semantic abstraction of the Large Concept Model (LCM) while eliminating its diffusion sampler and restoring a likelihood-based training signal. Across model sizes from 39M to 1.3B parameters, SONAR-LLM attains competitive generation quality.
18
+
19
+ Original repository: [FusionBrainLab/SONAR-LLM](https://github.com/FusionBrainLab/SONAR-LLM)
20
+
21
+ Paper: [arXiv:2508.05305](https://arxiv.org/abs/2508.05305)
22
+
23
+ Minimal bundle with SONAR-LLM 39M checkpoint and code.
24
+
25
+ ## Install
26
+ - Use a fresh venv/conda
27
+ - Install SONAR from the official repo: [facebookresearch/SONAR](https://github.com/facebookresearch/SONAR)
28
+ - Ensure PyTorch and transformers are installed
29
+ - (Optional) Download NLTK punkt: `python -c "import nltk; nltk.download('punkt')"`
30
+
31
+ ## Usage
32
+ ```python
33
+ from huggingface_hub import snapshot_download
34
+ import sys
35
+ p = snapshot_download("raxtemur/sonar-llm-39m")
36
+ sys.path.insert(0, p)
37
+
38
+ from sonarllm_model import SONARLLMGenerator, SONARLLMGenerationConfig
39
+
40
+ gen = SONARLLMGenerator.load_from_checkpoint(p)
41
+ eos_emb = gen.t2vec.predict(["End of sequence."], source_lang="eng_Latn").to(gen.device)
42
+ cfg = SONARLLMGenerationConfig(temperature=0)
43
+ print(gen.generate("Once upon a time", eos_emb, cfg))
44
+ ```
45
+
46
+ ## Files
47
+ - `pytorch_model.bin`
48
+ - `config.json`
49
+ - `sonarllm_model/`
50
+
51
+ ## Notes
52
+ - SONAR install guide: [facebookresearch/SONAR](https://github.com/facebookresearch/SONAR)
53
+ - Tokenizer name is taken from `config.json`.
config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "pretrained_model_name_or_path": "meta-llama/Llama-3.2-1B",
3
+ "llama_config": {
4
+ "hidden_size": 256,
5
+ "intermediate_size": 1024,
6
+ "num_hidden_layers": 8,
7
+ "num_attention_heads": 16,
8
+ "hidden_act": "silu",
9
+ "max_position_embeddings": 131072,
10
+ "initializer_range": 0.02,
11
+ "rms_norm_eps": 1e-06,
12
+ "use_cache": true,
13
+ "pretraining_tp": 1,
14
+ "tie_word_embeddings": true,
15
+ "rope_theta": 500000.0,
16
+ "rope_scaling": {
17
+ "factor": 32.0,
18
+ "high_freq_factor": 4.0,
19
+ "low_freq_factor": 1.0,
20
+ "original_max_position_embeddings": 8192,
21
+ "rope_type": "llama3"
22
+ },
23
+ "attention_bias": false,
24
+ "attention_dropout": 0.0,
25
+ "mlp_bias": false,
26
+ "head_dim": 16
27
+ },
28
+ "embed_dim": 1024
29
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d751098b27e6c6bc6c2270fa3dec7a32775dc50578524e621bddff48dc54c31
3
+ size 158635746
sonarllm_model/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .sonar_llm_model import SONARLLMGenerator, SONARLLMGenerationConfig
2
+
3
+
4
+
5
+
sonarllm_model/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (323 Bytes). View file
 
sonarllm_model/__pycache__/embedding_to_text_with_scores.cpython-311.pyc ADDED
Binary file (4.54 kB). View file
 
sonarllm_model/__pycache__/sonar_llm_model.cpython-311.pyc ADDED
Binary file (22.7 kB). View file
 
sonarllm_model/embedding_to_text_with_scores.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from typing import Iterable, List, Optional
3
+
4
+ import torch
5
+
6
+ from fairseq2.generation import (
7
+ BeamSearchSeq2SeqGenerator,
8
+ Sampler,
9
+ SamplingSeq2SeqGenerator,
10
+ Seq2SeqGenerator,
11
+ SequenceToTextConverter,
12
+ )
13
+
14
+ from sonar.inference_pipelines.utils import add_progress_bar
15
+ from sonar.inference_pipelines.text import (
16
+ EmbeddingToTextModelPipeline as _BaseEmbeddingToTextModelPipeline,
17
+ )
18
+ from fairseq2.data.data_pipeline import read_sequence
19
+
20
+
21
class EmbeddingToTextModelPipeline(_BaseEmbeddingToTextModelPipeline):
    """Embedding-to-text pipeline that can optionally report hypothesis scores.

    Behaves exactly like the base SONAR pipeline when ``return_scores`` is
    False (returns ``List[str]``).  With ``return_scores=True`` it returns a
    ``(List[str], List[float])`` tuple where each float is the top-hypothesis
    score reported by fairseq2 (sum of token log-probabilities when
    ``normalize_scores=False``, length-normalized per fairseq2 semantics
    otherwise).
    """

    @torch.inference_mode()
    def predict(
        self,
        inputs: torch.Tensor,
        target_lang: str,
        batch_size: int = 5,
        progress_bar: bool = False,
        sampler: Optional[Sampler] = None,
        return_scores: bool = False,
        **generator_kwargs,
    ):
        # Choose the sequence generator: beam search by default, sampling
        # when an explicit sampler is supplied.
        if sampler is None:
            generator: Seq2SeqGenerator = BeamSearchSeq2SeqGenerator(
                self.model, **generator_kwargs
            )
        else:
            generator = SamplingSeq2SeqGenerator(self.model, sampler, **generator_kwargs)

        converter = SequenceToTextConverter(
            generator,
            self.tokenizer,
            task="translation",
            target_lang=target_lang,
        )

        def _decode_batch(embeddings: List[torch.Tensor]):
            # Decode one bucket of source embeddings into text (and, if
            # requested, the per-sentence hypothesis scores).
            texts, gen_out = converter.batch_convert(
                torch.stack(embeddings).to(self.device), None
            )
            if not return_scores:
                return texts
            # Missing or score-less hypotheses fall back to 0.0.
            scores: List[float] = [
                float(hyps[0].score)
                if len(hyps) > 0 and hyps[0].score is not None
                else 0.0
                for hyps in gen_out.hypotheses
            ]
            return texts, scores

        data_pipeline: Iterable = (
            read_sequence(list(inputs))
            .bucket(batch_size)
            .map(_decode_batch)
            .and_return()
        )

        if progress_bar:
            data_pipeline = add_progress_bar(
                data_pipeline, inputs=inputs, batch_size=batch_size
            )

        batches: List = list(iter(data_pipeline))

        if not return_scores:
            # Each batch is a List[str] → flatten into one list.
            return [text for batch in batches for text in batch]

        # Each batch is (List[str], List[float]) → flatten both in parallel.
        flat_texts: List[str] = []
        flat_scores: List[float] = []
        for batch_texts, batch_scores in batches:
            flat_texts.extend(batch_texts)
            flat_scores.extend(batch_scores)
        return flat_texts, flat_scores
sonarllm_model/sonar_llm_model.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional, Tuple
4
+
5
+
6
+ import nltk
7
+ from nltk.tokenize import sent_tokenize
8
+ nltk.download("punkt", quiet=True)
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from .embedding_to_text_with_scores import EmbeddingToTextModelPipeline
15
+ from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
16
+
17
class Projector(nn.Module):
    """Affine projection between the SONAR embedding space and the LLaMA
    hidden-state space: a single linear layer, no activation."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # Attribute name `linear` is part of the checkpoint layout
        # (state-dict keys such as `forward_proj.linear.weight`) — keep it.
        self.linear = nn.Linear(in_features=in_dim, out_features=out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.linear(x)
        return projected
25
@dataclass
class SONARLLMGenerationConfig:
    """Settings for SONARLLMGenerator.generate (sentence-level beam search)."""

    # Outer sentence-level beam
    sentence_beam_size: int = 4  # number of beam states kept after each sentence step
    latent_samples_per_step: int = 4  # M latent variants per active beam state

    # Token-level decoder params (forwarded to the SONAR decoder's predict())
    decoder_beam_size: int = 5  # default in fairseq2
    decoder_temperature: float = 1.0  # default in fairseq2
    normalize_sentence_scores: bool = True  # False → sum of token log-probs
    decoder_max_len: int = 256  # max token sequence length per decoded sentence

    # Latent sampling
    temperature: float = 0.4  # noise scale; <= 0 disables latent noise entirely
    latent_top_p: Optional[float] = None  # 0<p<=1 truncates the noise radius; None → full Gaussian
    temperature_mode: str = "relative"  # "absolute" | "relative" (sigma scaled by hidden-state RMS)

    # Repetition control in latent space (mean-shift away from recent directions)
    repetition_penalty: float = 0.0  # 0 → no shift … 1 → maximum shift (see _apply_repetition_penalty_to_direction)
    repetition_memory: int = 0  # how many recent directions to remember; 0 disables the memory

    # Termination
    max_sentences: int = 32  # hard cap on the number of sentence steps
    eos_threshold: float = 0.98  # cosine similarity to eos_emb at which a beam counts as finished
49
+
50
+
51
class SONARLLMGenerator(torch.nn.Module):
    """Sentence-level beam search over latent SONAR embeddings.

    For each step:
    - Run LLaMA on the sentence-embedding history to get the final hidden state.
    - Sample multiple latent directions (temperature/latent_top_p, with an
      optional repetition penalty applied in direction space).
    - Project to SONAR space via `reverse_proj` and decode text with the SONAR decoder.
    - Score each candidate by the decoder sentence log-probability.
    - Keep the top `sentence_beam_size` states and continue until EOS or max sentences.

    This class does NOT modify existing project files and can be used standalone.
    """

    def __init__(
        self,
        llama_model: nn.Module,
        forward_proj: nn.Module,
        reverse_proj: nn.Module,
        sonar_decoder: EmbeddingToTextModelPipeline,
        t2vec_model: TextToEmbeddingModelPipeline,
        device: torch.device,
        add_begin: bool = False,
    ) -> None:
        """Wrap the frozen submodules used at generation time.

        Args:
            llama_model: Causal LM backbone run on projected sentence embeddings.
            forward_proj: Maps SONAR embeddings to the LLaMA hidden size.
            reverse_proj: Maps LLaMA hidden states back to SONAR embedding size.
            sonar_decoder: Embedding-to-text pipeline supporting ``return_scores``.
            t2vec_model: SONAR text-to-embedding pipeline.
            device: Device all tensors are moved to.
            add_begin: If True, a synthetic "Begin of text." sentence is
                prepended to the prefix and stripped from the final output.
        """
        super().__init__()
        # All submodules are used in inference mode only.
        self.llama_model = llama_model.eval()
        self.forward_proj = forward_proj.eval()
        self.reverse_proj = reverse_proj.eval()
        self.sonar_decoder = sonar_decoder.eval()
        self.t2vec = t2vec_model.eval()
        self.device = device
        self.add_begin = add_begin

    @torch.no_grad()
    def generate(self, prefix_text: str, eos_emb: torch.Tensor, cfg: Optional[SONARLLMGenerationConfig] = None) -> str:
        """Generate a continuation of ``prefix_text`` and return the full text.

        Args:
            prefix_text: Prompt text; split into sentences with NLTK.
            eos_emb: SONAR embedding of the end-of-sequence sentence, compared
                by cosine similarity for early stopping.
            cfg: Generation settings; defaults to ``SONARLLMGenerationConfig()``.
        """
        # Normalize and attach config to the instance for helper use
        if cfg is None:
            cfg = SONARLLMGenerationConfig()
        self._cfg = cfg
        sents = sent_tokenize(prefix_text)
        if self.add_begin:
            sents = ["Begin of text."] + sents

        if len(sents) == 0:
            sents = [prefix_text.strip()]

        # Initialize prefix embeddings
        emb_seq = self.t2vec.predict(sents, source_lang="eng_Latn").to(self.device)

        # Beam state tuple: (sentences, embeddings_seq, cumulative_score, recent_dirs)
        beams: List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]] = [
            (sents[:], emb_seq, 0.0, [])
        ]

        steps = 0
        while steps < self._cfg.max_sentences:
            steps += 1
            candidates: List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]] = []

            for (hist_sents, hist_emb, score, recent_dirs) in beams:
                candidates.extend(
                    self._expand_beam_state(hist_sents, hist_emb, score, recent_dirs, eos_emb)
                )

            # Keep top-k beams by cumulative score
            if len(candidates) == 0:
                break
            candidates.sort(key=lambda b: b[2], reverse=True)
            beams = candidates[: int(self._cfg.sentence_beam_size)]

            # If all beams look ended by EOS threshold, stop early
            if self._all_close_to_eos(beams, eos_emb):
                break

        best = max(beams, key=lambda b: b[2])
        result = self._join_sentences(best[0])
        if self.add_begin:
            # Bug fix: sentences are joined with " ", so slicing off the
            # marker alone left a leading space — strip it as well.
            result = result[len("Begin of text."):].lstrip()
        return result

    # --- internals ---

    @torch.no_grad()
    def _forward_hidden(self, emb_seq: torch.Tensor) -> torch.Tensor:
        """Run LLaMA over the (projected) embedding history; return the last
        position's final-layer hidden state as a 1-D tensor."""
        # Add a batch dimension for unbatched (seq, dim) histories.
        proj = self.forward_proj(emb_seq.unsqueeze(0)) if emb_seq.ndim == 2 else self.forward_proj(emb_seq)
        out = self.llama_model(inputs_embeds=proj, output_hidden_states=True)
        hidden = out.hidden_states[-1]
        return hidden[0, -1, :]

    def _join_sentences(self, sents: List[str]) -> str:
        """Concatenate sentences with single spaces."""
        return " ".join(sents)

    def _update_recent_dirs(
        self, recent: List[torch.Tensor], u: torch.Tensor, memory_cap: int
    ) -> List[torch.Tensor]:
        """Append direction ``u`` to the memory, keeping at most ``memory_cap``
        entries. Non-finite directions and a zero cap leave the list unchanged."""
        if memory_cap <= 0:
            return recent
        if not torch.isfinite(u).all():
            return recent
        # Store on CPU so beam states do not pin GPU memory.
        new_recent = recent + [u.detach().to("cpu")]
        if len(new_recent) > int(memory_cap):
            new_recent = new_recent[-int(memory_cap) :]
        return new_recent

    def _sample_noise_direction(
        self, final_hidden: torch.Tensor, recent_dirs: List[torch.Tensor]
    ) -> torch.Tensor:
        """Sample a unit-norm noise direction, optionally mean-shifted away
        from recently used directions."""
        g = torch.randn_like(final_hidden)
        # Bug fix: the original gate was ``repetition_penalty != 1.0`` — a
        # token-space convention where 1.0 is neutral.  Here the documented
        # mapping (see _apply_repetition_penalty_to_direction) is 0.0 → no
        # shift and 1.0 → maximum shift, so the old gate disabled the penalty
        # exactly at its strongest setting and ran a no-op call at 0.0.
        if (
            self._cfg.repetition_penalty is not None
            and float(self._cfg.repetition_penalty) > 0.0
            and self._cfg.repetition_memory > 0
            and len(recent_dirs) > 0
        ):
            g = self._apply_repetition_penalty_to_direction(
                g, float(self._cfg.repetition_penalty), int(self._cfg.repetition_memory), recent_dirs
            )
        # Normalize to a unit vector (epsilon guards a zero-norm sample).
        return g / (g.norm(p=2) + 1e-12)

    def _sample_noise(
        self, final_hidden: torch.Tensor, dir_unit: torch.Tensor
    ) -> torch.Tensor:
        """Scale the unit direction into a noise vector per the config.

        temperature <= 0 disables noise; "relative" mode scales sigma by the
        RMS of the hidden state, "absolute" uses the temperature directly.
        """
        t = float(self._cfg.temperature)
        if t <= 0.0:
            return torch.zeros_like(final_hidden)

        if self._cfg.temperature_mode not in ("absolute", "relative"):
            raise ValueError(f"Unsupported temperature_mode: {self._cfg.temperature_mode}")

        if self._cfg.temperature_mode == "absolute":
            sigma = torch.tensor(t, device=final_hidden.device, dtype=final_hidden.dtype)
        else:
            # RMS computed in float32 for numerical stability.
            rms = torch.sqrt(torch.mean(final_hidden.to(torch.float32) ** 2))
            rms = torch.clamp(rms, min=1e-12).to(dtype=final_hidden.dtype, device=final_hidden.device)
            sigma = rms * t

        top_p = self._cfg.latent_top_p
        if top_p is None:
            top_p = 1.0
        return self._sample_truncated_normal_like(final_hidden, float(top_p), sigma, dir_unit)

    def _sample_truncated_normal_like(
        self, base_vector: torch.Tensor, top_p: float, sigma: torch.Tensor, dir_unit: torch.Tensor
    ) -> torch.Tensor:
        """Sample a radius from a top-p-truncated chi distribution (via the
        Wilson–Hilferty approximation of ChiSquare quantiles) and scale
        ``dir_unit`` by it."""
        dim = base_vector.numel()
        device = base_vector.device
        # Uniform sample restricted to the lowest top_p quantile mass.
        u = torch.rand((), device=device, dtype=torch.float32)
        p = torch.clamp(u * float(top_p), min=1e-12, max=1.0 - 1e-12)
        k = torch.tensor(float(dim), device=device, dtype=torch.float32)
        # Standard-normal quantile via erfinv.
        z = torch.sqrt(torch.tensor(2.0, device=device, dtype=torch.float32)) * torch.special.erfinv(2.0 * p - 1.0)
        # Wilson–Hilferty: chi2_quantile ≈ k * (1 - 2/(9k) + z*sqrt(2/(9k)))^3.
        term = 1.0 - 2.0 / (9.0 * k) + z * torch.sqrt(2.0 / (9.0 * k))
        term = torch.clamp(term, min=1e-12)
        s = k * (term ** 3)
        r = torch.sqrt(torch.clamp(s, min=1e-12)).to(dtype=base_vector.dtype)
        return dir_unit * (r * sigma)

    def _expand_beam_state(
        self,
        hist_sents: List[str],
        hist_emb: torch.Tensor,
        score: float,
        recent_dirs: List[torch.Tensor],
        eos_emb: torch.Tensor,
    ) -> List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]]:
        """Expand one beam state into candidate next states.

        Returns a list of (new_hist_sents, new_hist_emb, new_score, new_recent_dirs).
        """
        final_hidden = self._forward_hidden(hist_emb)
        out: List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]] = []

        for _ in range(max(1, int(self._cfg.latent_samples_per_step))):
            dir_unit = self._sample_noise_direction(final_hidden, recent_dirs)
            noise = self._sample_noise(final_hidden, dir_unit)
            h_perturbed = final_hidden + noise
            z = self.reverse_proj(h_perturbed.unsqueeze(0))

            # Decode the latent to text; scores are the decoder's sentence
            # log-probabilities (normalization per config).
            texts, scores = self.sonar_decoder.predict(
                z,
                target_lang="eng_Latn",
                beam_size=int(self._cfg.decoder_beam_size),
                normalize_scores=bool(self._cfg.normalize_sentence_scores),
                max_seq_len=self._cfg.decoder_max_len,
                temperature=float(self._cfg.decoder_temperature),
                return_scores=True,
            )
            text = texts[0]
            sent_logprob = float(scores[0])

            # Re-encode the decoded text so the history stays in the space the
            # encoder produces (rather than the raw reverse-projected latent).
            z_re = self.t2vec.predict([text], source_lang="eng_Latn").to(self.device)

            cand_score = score + sent_logprob
            new_recent = self._update_recent_dirs(recent_dirs, dir_unit, self._cfg.repetition_memory)

            new_hist_sents = hist_sents + [text]
            new_hist_emb = torch.cat([hist_emb, z_re], dim=0)

            out.append((new_hist_sents, new_hist_emb, cand_score, new_recent))

        return out

    def _apply_repetition_penalty_to_direction(
        self, g: torch.Tensor, penalty: float, memory_cap: int, recent_dirs: List[torch.Tensor]
    ) -> torch.Tensor:
        """Mean-shift (A+) repetition penalty in latent direction space.

        - penalty is clamped to [0, 1].
        - penalty = 0 → no shift (q = 0.5).
        - penalty = 1 → maximum shift (q ≈ q_min).
        Mapping: q = 0.5^(1-penalty) * q_min^(penalty), beta = Phi^{-1}(1 - q),
        and we set g' = g - beta * b_unit, where b_unit is the normalized average of recent directions.
        """
        if memory_cap <= 0 or len(recent_dirs) == 0:
            return g

        # Aggregate and normalize recent directions
        B = torch.stack(
            [u.to(device=g.device, dtype=g.dtype) for u in recent_dirs[-int(memory_cap):]], dim=0
        )
        b = B.mean(dim=0)
        bn = b.norm(p=2)
        if not torch.isfinite(bn) or bn <= 1e-12:
            # Recent directions cancel out — nothing to shift away from.
            return g
        b_unit = b / bn

        # Clamp penalty to [0, 1] and map penalty → beta via q (in log space).
        rp = float(penalty)
        if rp < 0.0:
            rp = 0.0
        if rp > 1.0:
            rp = 1.0
        q_min = 1e-12
        log_q = (1.0 - rp) * torch.log(torch.tensor(0.5, device=g.device, dtype=torch.float32))
        log_q = log_q + rp * torch.log(torch.tensor(q_min, device=g.device, dtype=torch.float32))
        q = torch.exp(log_q)
        p = torch.clamp(1.0 - q, 1e-12, 1.0 - 1e-12)
        # beta = Phi^{-1}(p) via erfinv; clamp keeps the shift finite at q→q_min.
        beta = torch.sqrt(torch.tensor(2.0, device=g.device, dtype=g.dtype)) * torch.special.erfinv(2.0 * p - 1.0)
        beta = torch.clamp(beta, 0.0, 7.5)
        return g - (beta * b_unit)

    def _all_close_to_eos(self, beams, eos_emb: torch.Tensor) -> bool:
        """True when the last sentence embedding of EVERY beam is within the
        configured cosine-similarity threshold of the EOS embedding."""
        for (_, emb, _, _) in beams:
            last = emb[-1:, :]
            sim = F.cosine_similarity(last, eos_emb, dim=1).item()
            if sim < float(self._cfg.eos_threshold):
                return False
        return True

    # --- factory ---
    @classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_dir: str,
        device: Optional[torch.device] = None,
        generation_config: Optional[SONARLLMGenerationConfig] = None,
    ) -> "SONARLLMGenerator":
        """Load generator from a folder with config.json and weights.

        The folder is expected to contain:
        - config.json (with keys: pretrained_model_name_or_path, llama_config?, embed_dim)
        - pytorch_model.bin (or model_state_dict inside the saved file)

        NOTE(review): ``generation_config`` is accepted for API symmetry but is
        not stored — pass a config to ``generate()`` instead.
        """
        import json
        import os
        from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
        from .embedding_to_text_with_scores import EmbeddingToTextModelPipeline
        from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        cfg_path = os.path.join(checkpoint_dir, "config.json")
        with open(cfg_path, "r", encoding="utf-8") as f:
            cfg = json.load(f)

        tokenizer = AutoTokenizer.from_pretrained(cfg["pretrained_model_name_or_path"])
        tokenizer.pad_token = tokenizer.eos_token

        # Size the LLaMA config to the tokenizer; fall back to the Llama-3
        # special token ids when the tokenizer does not define them.
        llama_cfg_dict = cfg.get("llama_config", {})
        llama_cfg_dict["vocab_size"] = len(tokenizer)
        llama_cfg_dict["pad_token_id"] = tokenizer.pad_token_id
        llama_cfg_dict["bos_token_id"] = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 128000
        llama_cfg_dict["eos_token_id"] = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 128001
        # Bug fix: always build the config from the mutated dict. The original
        # fell back to a bare LlamaConfig() when "llama_config" was absent,
        # silently discarding the tokenizer-derived fields set above.
        llama_cfg = LlamaConfig(**llama_cfg_dict)

        llama_model = LlamaForCausalLM(llama_cfg).to(device).eval()

        hidden_size = llama_cfg.hidden_size
        embed_dim = cfg.get("embed_dim", 1024)

        t2vec_model = TextToEmbeddingModelPipeline(
            encoder="text_sonar_basic_encoder",
            tokenizer="text_sonar_basic_encoder",
            device=device,
        ).eval()

        vec2text_model = EmbeddingToTextModelPipeline(
            decoder="text_sonar_basic_decoder",
            tokenizer="text_sonar_basic_encoder",
            device=device,
        ).eval()

        forward_projector = Projector(embed_dim, hidden_size).to(device).eval()
        reverse_projector = Projector(hidden_size, embed_dim).to(device).eval()

        gen = cls(
            llama_model,
            forward_projector,
            reverse_projector,
            vec2text_model,
            t2vec_model,
            device,
            add_begin=cfg.get("add_begin", False),
        )

        # Load weights into generator to cover llama + projectors
        ckpt_bin = os.path.join(checkpoint_dir, "pytorch_model.bin")
        state = torch.load(ckpt_bin, map_location=device, weights_only=True)
        state = state.get("model_state_dict", state)
        raw = gen.module if hasattr(gen, "module") else gen
        raw.load_state_dict(state, strict=False)

        return gen