Jac-Zac committed on
Commit
e2cecb1
·
1 Parent(s): 4c8079c

Small update to slim down the return value

Browse files
Files changed (2) hide show
  1. utils/chat.py +8 -10
  2. uv.lock +3 -3
utils/chat.py CHANGED
@@ -1,5 +1,5 @@
1
  import logging
2
- from contextlib import contextmanager, nullcontext
3
  from dataclasses import dataclass
4
  from typing import Literal
5
 
@@ -174,9 +174,11 @@ def generate_chat_reply(
174
 
175
  generation_kwargs: dict[str, object] = {
176
  "max_new_tokens": max_new_tokens,
177
- "return_dict_in_generate": True,
178
  "use_cache": True,
179
  }
 
 
 
180
  if do_sample:
181
  generation_kwargs["do_sample"] = True
182
  generation_kwargs["temperature"] = temperature
@@ -186,21 +188,17 @@ def generate_chat_reply(
186
  generation_kwargs["repetition_penalty"] = repetition_penalty
187
  if past_key_values is not None and not remote:
188
  generation_kwargs["past_key_values"] = past_key_values
189
- if remote:
190
- generation_kwargs["remote"] = True
191
- # WARNING: NDIF returns caches on CPU, so cross-turn cache reuse is not stable.
192
 
 
 
193
  with _seeded_rng(seed if do_sample and not remote else None):
194
- with model.generate(prompt, **generation_kwargs) as tracer:
195
  generated = tracer.result.save()
196
 
197
  if hasattr(generated, "value") and getattr(generated, "value") is not None:
198
  generated = generated.value
199
 
200
- if not hasattr(generated, "sequences"):
201
- raise ValueError("Generation did not return token sequences")
202
-
203
- sequences = generated.sequences
204
  if not isinstance(sequences, torch.Tensor):
205
  raise TypeError("Generated sequences must be a tensor")
206
 
 
1
  import logging
2
+ from contextlib import contextmanager
3
  from dataclasses import dataclass
4
  from typing import Literal
5
 
 
174
 
175
  generation_kwargs: dict[str, object] = {
176
  "max_new_tokens": max_new_tokens,
 
177
  "use_cache": True,
178
  }
179
+ if not remote:
180
+ # Not needed when running remotely, where it also slows down the download drastically
181
+ generation_kwargs["return_dict_in_generate"] = True
182
  if do_sample:
183
  generation_kwargs["do_sample"] = True
184
  generation_kwargs["temperature"] = temperature
 
188
  generation_kwargs["repetition_penalty"] = repetition_penalty
189
  if past_key_values is not None and not remote:
190
  generation_kwargs["past_key_values"] = past_key_values
 
 
 
191
 
192
+ # `remote` is captured by nnsight's RemoteableMixin.trace() and is NOT
193
+ # forwarded to the underlying model's generate
194
  with _seeded_rng(seed if do_sample and not remote else None):
195
+ with model.generate(prompt, remote=remote, **generation_kwargs) as tracer:
196
  generated = tracer.result.save()
197
 
198
  if hasattr(generated, "value") and getattr(generated, "value") is not None:
199
  generated = generated.value
200
 
201
+ sequences = generated.sequences if hasattr(generated, "sequences") else generated
 
 
 
202
  if not isinstance(sequences, torch.Tensor):
203
  raise TypeError("Generated sequences must be a tensor")
204
 
uv.lock CHANGED
@@ -1207,7 +1207,7 @@ requires-dist = [
1207
 
1208
  [[package]]
1209
  name = "persona-vectors"
1210
- version = "0.1.1"
1211
  source = { registry = "https://pypi.org/simple" }
1212
  dependencies = [
1213
  { name = "kaleido" },
@@ -1224,9 +1224,9 @@ dependencies = [
1224
  { name = "transformers" },
1225
  { name = "umap-learn" },
1226
  ]
1227
- sdist = { url = "https://files.pythonhosted.org/packages/b5/f8/618d9380a222d9d639bf62f6e5950a7a475687e82daec2fd9b60f2571dac/persona_vectors-0.1.1.tar.gz", hash = "sha256:b60f19aca42b4b2a67a8d3bcb069891a370d53a2d62dfbfe26aebe62645a8ad9", size = 10918, upload-time = "2026-04-09T10:00:37.622Z" }
1228
  wheels = [
1229
- { url = "https://files.pythonhosted.org/packages/0d/48/fa7eb3ca7655af9a798e2a37ba0fb39873e4ddf54e40ba74eb7940a6300d/persona_vectors-0.1.1-py3-none-any.whl", hash = "sha256:fc91a70f71f2c042a159cab78e4e87d1841219ebdf9df679cf2c01d722b59d32", size = 14429, upload-time = "2026-04-09T10:00:38.381Z" },
1230
  ]
1231
 
1232
  [[package]]
 
1207
 
1208
  [[package]]
1209
  name = "persona-vectors"
1210
+ version = "0.1.2"
1211
  source = { registry = "https://pypi.org/simple" }
1212
  dependencies = [
1213
  { name = "kaleido" },
 
1224
  { name = "transformers" },
1225
  { name = "umap-learn" },
1226
  ]
1227
+ sdist = { url = "https://files.pythonhosted.org/packages/86/ce/4bd6a69dd268ddb7eebf57e1d770a706483682c1aac77181502f94787b45/persona_vectors-0.1.2.tar.gz", hash = "sha256:8f14c5839e619e6a5e3902e54d5335bd58e92dd5d4d1559c5aabc5417084aacd", size = 10942, upload-time = "2026-04-09T10:30:10.731Z" }
1228
  wheels = [
1229
+ { url = "https://files.pythonhosted.org/packages/6d/a3/0e033b727f288564c166c9ef15a338119f52dea5fe2886970c34c03d951b/persona_vectors-0.1.2-py3-none-any.whl", hash = "sha256:dda78fbdf0815bc49d4069b16f4b85526a775637b85a14ba1be87dfb5f1f280c", size = 14451, upload-time = "2026-04-09T10:30:09.668Z" },
1230
  ]
1231
 
1232
  [[package]]