LisaMegaWatts committed on
Commit
c79eabb
·
verified ·
1 Parent(s): 280ed9e

fix: real token-by-token streaming (was generating all tokens then splitting by spaces)
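For context, a minimal client-side sketch of consuming the fixed stream. The host, port, and endpoint path are assumptions (the diff only shows the chat_completions handler and the OpenAI-style "chat.completion.chunk" SSE payloads it emits), so treat this as illustrative rather than the project's documented API:

    import json
    import requests  # any HTTP client with streaming support works

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",  # assumed host/path
        json={
            "model": "symbiogpt-10m",
            "messages": [{"role": "user", "content": "Hello"}],
            "stream": True,  # assumed flag; only the streaming branch appears in the diff
        },
        stream=True,
    )
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data: "):
            continue
        payload = line[len(b"data: "):]
        if payload == b"[DONE]":  # defensive: some OpenAI-style servers send a sentinel
            break
        delta = json.loads(payload)["choices"][0]["delta"]
        # After this fix, each chunk carries one sampled token instead of a whole word.
        print(delta.get("content", ""), end="", flush=True)

With the fix, chunks arrive as each token is sampled, instead of appearing all at once after generation finishes.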

Files changed (1):
  server.py +21 -27
server.py CHANGED
@@ -93,37 +93,33 @@ print(f" Organelles: {MODEL_CONFIG.organelles}")
 
 
 @torch.no_grad()
-def generate(
+def generate_streaming(
     prompt: str,
     max_tokens: int = 200,
     temperature: float = 0.8,
     top_k: int = 40,
     top_p: float = 1.0,
-    on_token=None,
-) -> str:
+):
+    """Generator yielding token strings one at a time for real SSE streaming."""
     tokens = tokenizer.encode(prompt)
     if not tokens:
         tokens = [0]
     idx = torch.tensor([tokens], dtype=torch.long)
-    generated_ids = []
 
     for _ in range(max_tokens):
         idx_cond = idx[:, -MODEL_CONFIG.context_length:]
         logits = model(idx_cond)
         logits_last = logits[0, -1, :].float()
 
-        # Temperature
         if temperature > 0.01:
             logits_last = logits_last / temperature
         else:
             logits_last = logits_last / 0.01
 
-        # Top-k
         if 0 < top_k < logits_last.size(0):
             threshold = torch.topk(logits_last, top_k).values[-1]
             logits_last[logits_last < threshold] = float("-inf")
 
-        # Top-p
         if top_p < 1.0:
             sorted_logits, sorted_indices = torch.sort(logits_last, descending=True)
             probs_sorted = F.softmax(sorted_logits, dim=-1)
@@ -134,14 +130,20 @@ def generate(
 
         probs = F.softmax(logits_last, dim=-1)
         next_id = torch.multinomial(probs, 1).item()
-        generated_ids.append(next_id)
         idx = torch.cat([idx, torch.tensor([[next_id]])], dim=1)
+        yield tokenizer.decode([next_id])
 
-        if on_token is not None:
-            token_str = tokenizer.decode([next_id])
-            on_token(token_str)
 
-    return tokenizer.decode(generated_ids)
+
+@torch.no_grad()
+def generate(
+    prompt: str,
+    max_tokens: int = 200,
+    temperature: float = 0.8,
+    top_k: int = 40,
+    top_p: float = 1.0,
+) -> str:
+    """Generate complete text (non-streaming wrapper)."""
+    return "".join(generate_streaming(prompt, max_tokens, temperature, top_k, top_p))
 
 
 # ═══════════════════════════════════════════════════════════════════
@@ -231,7 +233,6 @@ async def chat_completions(request: Request):
     import json as json_mod
 
     def sse_stream():
-        # Initial chunk
         initial = {
             "id": completion_id,
             "object": "chat.completion.chunk",
@@ -242,27 +243,20 @@ async def chat_completions(request: Request):
         yield f"data: {json_mod.dumps(initial)}\n\n"
 
         token_count = 0
-
-        def on_token(token_str):
-            nonlocal token_count
+        for token_str in generate_streaming(
+            prompt_text, max_tokens=max_tokens, temperature=temperature,
+            top_k=top_k_val, top_p=top_p_val,
+        ):
             token_count += 1
-
-        text = generate(prompt_text, max_tokens=max_tokens, temperature=temperature,
-                        top_k=top_k_val, top_p=top_p_val, on_token=on_token)
-
-        # Send all generated text as chunks (word-level for readability)
-        for word in text.split(" "):
-            chunk_text = word + " " if word else ""
             chunk = {
                 "id": completion_id,
                 "object": "chat.completion.chunk",
                 "created": created,
                 "model": "symbiogpt-10m",
-                "choices": [{"index": 0, "delta": {"content": chunk_text}, "finish_reason": None}],
+                "choices": [{"index": 0, "delta": {"content": token_str}, "finish_reason": None}],
             }
             yield f"data: {json_mod.dumps(chunk)}\n\n"
 
-        # Final chunk
         finish = {
             "id": completion_id,
             "object": "chat.completion.chunk",
@@ -271,8 +265,8 @@ async def chat_completions(request: Request):
             "choices": [{"index": 0, "delta": {}, "finish_reason": "length" if token_count >= max_tokens else "stop"}],
             "usage": {
                 "prompt_tokens": prompt_tokens,
-                "completion_tokens": max_tokens,
-                "total_tokens": prompt_tokens + max_tokens,
+                "completion_tokens": token_count,
+                "total_tokens": prompt_tokens + token_count,
             },
         }
         yield f"data: {json_mod.dumps(finish)}\n\n"