Jac-Zac committed on
Commit ·
e2cecb1
1
Parent(s): 4c8079c
Small update to slim up return
Browse files
- utils/chat.py +8 -10
- uv.lock +3 -3
utils/chat.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
import logging
|
| 2 |
-
from contextlib import contextmanager
|
| 3 |
from dataclasses import dataclass
|
| 4 |
from typing import Literal
|
| 5 |
|
|
@@ -174,9 +174,11 @@ def generate_chat_reply(
|
|
| 174 |
|
| 175 |
generation_kwargs: dict[str, object] = {
|
| 176 |
"max_new_tokens": max_new_tokens,
|
| 177 |
-
"return_dict_in_generate": True,
|
| 178 |
"use_cache": True,
|
| 179 |
}
|
|
|
|
|
|
|
|
|
|
| 180 |
if do_sample:
|
| 181 |
generation_kwargs["do_sample"] = True
|
| 182 |
generation_kwargs["temperature"] = temperature
|
|
@@ -186,21 +188,17 @@ def generate_chat_reply(
|
|
| 186 |
generation_kwargs["repetition_penalty"] = repetition_penalty
|
| 187 |
if past_key_values is not None and not remote:
|
| 188 |
generation_kwargs["past_key_values"] = past_key_values
|
| 189 |
-
if remote:
|
| 190 |
-
generation_kwargs["remote"] = True
|
| 191 |
-
# WARNING: NDIF returns caches on CPU, so cross-turn cache reuse is not stable.
|
| 192 |
|
|
|
|
|
|
|
| 193 |
with _seeded_rng(seed if do_sample and not remote else None):
|
| 194 |
-
with model.generate(prompt, **generation_kwargs) as tracer:
|
| 195 |
generated = tracer.result.save()
|
| 196 |
|
| 197 |
if hasattr(generated, "value") and getattr(generated, "value") is not None:
|
| 198 |
generated = generated.value
|
| 199 |
|
| 200 |
-
|
| 201 |
-
raise ValueError("Generation did not return token sequences")
|
| 202 |
-
|
| 203 |
-
sequences = generated.sequences
|
| 204 |
if not isinstance(sequences, torch.Tensor):
|
| 205 |
raise TypeError("Generated sequences must be a tensor")
|
| 206 |
|
|
|
|
| 1 |
import logging
|
| 2 |
+
from contextlib import contextmanager
|
| 3 |
from dataclasses import dataclass
|
| 4 |
from typing import Literal
|
| 5 |
|
|
|
|
| 174 |
|
| 175 |
generation_kwargs: dict[str, object] = {
|
| 176 |
"max_new_tokens": max_new_tokens,
|
|
|
|
| 177 |
"use_cache": True,
|
| 178 |
}
|
| 179 |
+
if not remote:
|
| 180 |
+
# Not needed for remote runs, where it also drastically slows down the download
|
| 181 |
+
generation_kwargs["return_dict_in_generate"] = True
|
| 182 |
if do_sample:
|
| 183 |
generation_kwargs["do_sample"] = True
|
| 184 |
generation_kwargs["temperature"] = temperature
|
|
|
|
| 188 |
generation_kwargs["repetition_penalty"] = repetition_penalty
|
| 189 |
if past_key_values is not None and not remote:
|
| 190 |
generation_kwargs["past_key_values"] = past_key_values
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
+
# `remote` is captured by nnsight's RemoteableMixin.trace() and is NOT
|
| 193 |
+
# forwarded to the underlying model's generate
|
| 194 |
with _seeded_rng(seed if do_sample and not remote else None):
|
| 195 |
+
with model.generate(prompt, remote=remote, **generation_kwargs) as tracer:
|
| 196 |
generated = tracer.result.save()
|
| 197 |
|
| 198 |
if hasattr(generated, "value") and getattr(generated, "value") is not None:
|
| 199 |
generated = generated.value
|
| 200 |
|
| 201 |
+
sequences = generated.sequences if hasattr(generated, "sequences") else generated
|
|
|
|
|
|
|
|
|
|
| 202 |
if not isinstance(sequences, torch.Tensor):
|
| 203 |
raise TypeError("Generated sequences must be a tensor")
|
| 204 |
|
uv.lock
CHANGED
|
@@ -1207,7 +1207,7 @@ requires-dist = [
|
|
| 1207 |
|
| 1208 |
[[package]]
|
| 1209 |
name = "persona-vectors"
|
| 1210 |
-
version = "0.1.
|
| 1211 |
source = { registry = "https://pypi.org/simple" }
|
| 1212 |
dependencies = [
|
| 1213 |
{ name = "kaleido" },
|
|
@@ -1224,9 +1224,9 @@ dependencies = [
|
|
| 1224 |
{ name = "transformers" },
|
| 1225 |
{ name = "umap-learn" },
|
| 1226 |
]
|
| 1227 |
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
| 1228 |
wheels = [
|
| 1229 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 1230 |
]
|
| 1231 |
|
| 1232 |
[[package]]
|
|
|
|
| 1207 |
|
| 1208 |
[[package]]
|
| 1209 |
name = "persona-vectors"
|
| 1210 |
+
version = "0.1.2"
|
| 1211 |
source = { registry = "https://pypi.org/simple" }
|
| 1212 |
dependencies = [
|
| 1213 |
{ name = "kaleido" },
|
|
|
|
| 1224 |
{ name = "transformers" },
|
| 1225 |
{ name = "umap-learn" },
|
| 1226 |
]
|
| 1227 |
+
sdist = { url = "https://files.pythonhosted.org/packages/86/ce/4bd6a69dd268ddb7eebf57e1d770a706483682c1aac77181502f94787b45/persona_vectors-0.1.2.tar.gz", hash = "sha256:8f14c5839e619e6a5e3902e54d5335bd58e92dd5d4d1559c5aabc5417084aacd", size = 10942, upload-time = "2026-04-09T10:30:10.731Z" }
|
| 1228 |
wheels = [
|
| 1229 |
+
{ url = "https://files.pythonhosted.org/packages/6d/a3/0e033b727f288564c166c9ef15a338119f52dea5fe2886970c34c03d951b/persona_vectors-0.1.2-py3-none-any.whl", hash = "sha256:dda78fbdf0815bc49d4069b16f4b85526a775637b85a14ba1be87dfb5f1f280c", size = 14451, upload-time = "2026-04-09T10:30:09.668Z" },
|
| 1230 |
]
|
| 1231 |
|
| 1232 |
[[package]]
|