fix: real token-by-token streaming (was generating all tokens then splitting by spaces)
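The old path ran the whole generation to completion and only then re-chunked the finished string on spaces, so clients saw nothing until the last token was sampled, and chunk boundaries did not line up with real tokens. A quick illustration of the mismatch (token ids made up; any BPE-style tokenizer behaves this way, since tokens are not space-delimited):

    text = "streaming"                           # one word...
    ids = tokenizer.encode(text)                 # ...but usually several tokens, e.g. [1204, 88, 302]
    text.split(" ")                              # ['streaming']: old code sent one chunk, after generation finished
    [tokenizer.decode([i]) for i in ids]         # e.g. ['stre', 'am', 'ing']: new code sends one chunk per sampled token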
server.py CHANGED
@@ -93,37 +93,33 @@ print(f" Organelles: {MODEL_CONFIG.organelles}")
 
 
 @torch.no_grad()
-def generate(
+def generate_streaming(
     prompt: str,
     max_tokens: int = 200,
     temperature: float = 0.8,
     top_k: int = 40,
     top_p: float = 1.0,
-    on_token=None,
-) -> str:
+):
+    """Generator yielding token strings one at a time for real SSE streaming."""
     tokens = tokenizer.encode(prompt)
     if not tokens:
         tokens = [0]
     idx = torch.tensor([tokens], dtype=torch.long)
-    generated_ids = []
 
     for _ in range(max_tokens):
         idx_cond = idx[:, -MODEL_CONFIG.context_length:]
         logits = model(idx_cond)
         logits_last = logits[0, -1, :].float()
 
-        # Temperature
         if temperature > 0.01:
             logits_last = logits_last / temperature
         else:
             logits_last = logits_last / 0.01
 
-        # Top-k
         if 0 < top_k < logits_last.size(0):
             threshold = torch.topk(logits_last, top_k).values[-1]
             logits_last[logits_last < threshold] = float("-inf")
 
-        # Top-p
         if top_p < 1.0:
             sorted_logits, sorted_indices = torch.sort(logits_last, descending=True)
             probs_sorted = F.softmax(sorted_logits, dim=-1)
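The hunk above cuts off right after probs_sorted; the nucleus cutoff itself sits in the unchanged lines between the two hunks and is not shown. For orientation only, a sketch of what that elided step conventionally looks like (an assumption about lines this diff does not display, not the file's verbatim contents):

    # assumed continuation: mask tokens outside the smallest set whose
    # cumulative probability exceeds top_p, then restore original order
    cumulative = torch.cumsum(probs_sorted, dim=-1)
    cutoff = cumulative - probs_sorted > top_p   # mass before this token already exceeds top_p
    sorted_logits[cutoff] = float("-inf")
    logits_last = torch.full_like(logits_last, float("-inf")).scatter(0, sorted_indices, sorted_logits)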
@@ -134,14 +130,20 @@ def generate(
 
         probs = F.softmax(logits_last, dim=-1)
         next_id = torch.multinomial(probs, 1).item()
-        generated_ids.append(next_id)
         idx = torch.cat([idx, torch.tensor([[next_id]])], dim=1)
+        yield tokenizer.decode([next_id])
 
-        if on_token is not None:
-            token_str = tokenizer.decode([next_id])
-            on_token(token_str)
 
-    return tokenizer.decode(generated_ids)
+@torch.no_grad()
+def generate(
+    prompt: str,
+    max_tokens: int = 200,
+    temperature: float = 0.8,
+    top_k: int = 40,
+    top_p: float = 1.0,
+) -> str:
+    """Generate complete text (non-streaming wrapper)."""
+    return "".join(generate_streaming(prompt, max_tokens, temperature, top_k, top_p))
 
 
 # ───────────────────────────────────────────────────────────────────
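After the split, the SSE handler below iterates generate_streaming directly, and non-streaming callers keep the blocking generate. A minimal usage sketch, assuming the module-level model and tokenizer are loaded:

    # streaming: each token string arrives as soon as it is sampled
    for tok in generate_streaming("Once upon a time", max_tokens=5):
        print(repr(tok), flush=True)

    # non-streaming: same sampling path, joined into one string
    print(generate("Once upon a time", max_tokens=5))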
@@ -231,7 +233,6 @@ async def chat_completions(request: Request):
     import json as json_mod
 
     def sse_stream():
-        # Initial chunk
         initial = {
             "id": completion_id,
             "object": "chat.completion.chunk",
@@ -242,27 +243,20 @@ async def chat_completions(request: Request):
         yield f"data: {json_mod.dumps(initial)}\n\n"
 
         token_count = 0
-
-        def on_token(token_str):
-            nonlocal token_count
+        for token_str in generate_streaming(
+            prompt_text, max_tokens=max_tokens, temperature=temperature,
+            top_k=top_k_val, top_p=top_p_val,
+        ):
             token_count += 1
-
-        text = generate(prompt_text, max_tokens=max_tokens, temperature=temperature,
-                        top_k=top_k_val, top_p=top_p_val, on_token=on_token)
-
-        # Send all generated text as chunks (word-level for readability)
-        for word in text.split(" "):
-            chunk_text = word + " " if word else ""
             chunk = {
                 "id": completion_id,
                 "object": "chat.completion.chunk",
                 "created": created,
                 "model": "symbiogpt-10m",
-                "choices": [{"index": 0, "delta": {"content": chunk_text}, "finish_reason": None}],
+                "choices": [{"index": 0, "delta": {"content": token_str}, "finish_reason": None}],
             }
             yield f"data: {json_mod.dumps(chunk)}\n\n"
 
-        # Final chunk
         finish = {
             "id": completion_id,
             "object": "chat.completion.chunk",
@@ -271,8 +265,8 @@ async def chat_completions(request: Request):
             "choices": [{"index": 0, "delta": {}, "finish_reason": "length" if token_count >= max_tokens else "stop"}],
             "usage": {
                 "prompt_tokens": prompt_tokens,
-                "completion_tokens": token_count,
-                "total_tokens": prompt_tokens + token_count,
+                "completion_tokens": token_count,
+                "total_tokens": prompt_tokens + token_count,
             },
         }
         yield f"data: {json_mod.dumps(finish)}\n\n"
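Each yield above is one server-sent event: a data: line carrying a JSON chunk, terminated by a blank line. A client can consume the stream by reading lines off the response; a sketch using requests (the route path and port are assumptions, and the data: [DONE] sentinel, if this server emits one after the finish chunk, is handled defensively):

    import json
    import requests

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",  # assumed path/port
        json={"model": "symbiogpt-10m", "stream": True,
              "messages": [{"role": "user", "content": "Hello"}]},
        stream=True,
    )
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data: "):
            continue
        payload = line[len(b"data: "):]
        if payload == b"[DONE]":                      # sentinel, if sent; not shown in this diff
            break
        delta = json.loads(payload)["choices"][0]["delta"]
        print(delta.get("content", ""), end="", flush=True)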
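The user-visible effect is time-to-first-token: before this change the first content chunk could not leave the server until all max_tokens forward passes had finished. A rough way to see the difference locally (illustrative, using the functions as defined above):

    import time

    start = time.perf_counter()
    for i, tok in enumerate(generate_streaming("Hello", max_tokens=20)):
        if i == 0:
            print(f"first token after {time.perf_counter() - start:.2f}s")  # about one forward pass
    print(f"all tokens after {time.perf_counter() - start:.2f}s")

    start = time.perf_counter()
    generate("Hello", max_tokens=20)
    print(f"generate() returned after {time.perf_counter() - start:.2f}s")  # old code sent its first chunk only now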