Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -44,7 +44,7 @@ print("Toddler is awake and ready!")
|
|
| 44 |
|
| 45 |
def chat_fn(message, history):
|
| 46 |
try:
|
| 47 |
-
# Build
|
| 48 |
tokens = [tokenizer.bos_token_id]
|
| 49 |
for user_msg, assistant_msg in history:
|
| 50 |
if user_msg:
|
|
@@ -52,14 +52,14 @@ def chat_fn(message, history):
|
|
| 52 |
if assistant_msg:
|
| 53 |
tokens.extend([tokenizer.assistant_start_id] + tokenizer.encode(assistant_msg) + [tokenizer.assistant_end_id])
|
| 54 |
|
| 55 |
-
# Add current user prompt
|
| 56 |
tokens.extend([tokenizer.user_start_id] + tokenizer.encode(message) + [tokenizer.user_end_id])
|
| 57 |
tokens.append(tokenizer.assistant_start_id)
|
| 58 |
|
| 59 |
input_ids = torch.tensor([tokens], dtype=torch.long)
|
| 60 |
|
| 61 |
-
#
|
| 62 |
with torch.no_grad():
|
|
|
|
| 63 |
output_ids = model.generate(
|
| 64 |
input_ids,
|
| 65 |
max_tokens=512,
|
|
@@ -67,28 +67,24 @@ def chat_fn(message, history):
|
|
| 67 |
top_k=40
|
| 68 |
)
|
| 69 |
|
| 70 |
-
#
|
| 71 |
if isinstance(output_ids, torch.Tensor):
|
|
|
|
| 72 |
new_tokens = output_ids[0][input_ids.shape[1]:]
|
| 73 |
response = tokenizer.decode(new_tokens.tolist())
|
| 74 |
else:
|
| 75 |
-
#
|
| 76 |
-
response = ""
|
| 77 |
-
for token in output_ids:
|
| 78 |
-
decoded = tokenizer.decode([token])
|
| 79 |
-
if "<|assistant_end|>" in decoded:
|
| 80 |
-
break
|
| 81 |
-
response += decoded
|
| 82 |
-
yield response
|
| 83 |
|
| 84 |
-
#
|
| 85 |
-
for tag in ["<|assistant_end|>", "<|end|>", "<|user_start|>"]:
|
| 86 |
response = response.split(tag)[0]
|
| 87 |
|
| 88 |
-
|
|
|
|
| 89 |
|
| 90 |
except Exception as e:
|
| 91 |
-
print(f"ERROR: {e}")
|
| 92 |
return f"Toddler tantrum: {str(e)}"
|
| 93 |
|
| 94 |
# 5. Launch UI (Cleaned for Gradio 6.0 compatibility)
|
|
|
|
| 44 |
|
| 45 |
def chat_fn(message, history):
|
| 46 |
try:
|
| 47 |
+
# 1. Build Token List
|
| 48 |
tokens = [tokenizer.bos_token_id]
|
| 49 |
for user_msg, assistant_msg in history:
|
| 50 |
if user_msg:
|
|
|
|
| 52 |
if assistant_msg:
|
| 53 |
tokens.extend([tokenizer.assistant_start_id] + tokenizer.encode(assistant_msg) + [tokenizer.assistant_end_id])
|
| 54 |
|
|
|
|
| 55 |
tokens.extend([tokenizer.user_start_id] + tokenizer.encode(message) + [tokenizer.user_end_id])
|
| 56 |
tokens.append(tokenizer.assistant_start_id)
|
| 57 |
|
| 58 |
input_ids = torch.tensor([tokens], dtype=torch.long)
|
| 59 |
|
| 60 |
+
# 2. Generate (Non-streaming for stability)
|
| 61 |
with torch.no_grad():
|
| 62 |
+
# In nanochat, generate usually returns the full sequence tensor
|
| 63 |
output_ids = model.generate(
|
| 64 |
input_ids,
|
| 65 |
max_tokens=512,
|
|
|
|
| 67 |
top_k=40
|
| 68 |
)
|
| 69 |
|
| 70 |
+
# 3. Process Output
|
| 71 |
if isinstance(output_ids, torch.Tensor):
|
| 72 |
+
# Slicing to get only new tokens
|
| 73 |
new_tokens = output_ids[0][input_ids.shape[1]:]
|
| 74 |
response = tokenizer.decode(new_tokens.tolist())
|
| 75 |
else:
|
| 76 |
+
# If it's a generator, collect it all into one string
|
| 77 |
+
response = "".join([tokenizer.decode([t]) for t in output_ids])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
+
# 4. Clean up tags
|
| 80 |
+
for tag in ["<|assistant_end|>", "<|end|>", "<|user_start|>", "<|bos|>"]:
|
| 81 |
response = response.split(tag)[0]
|
| 82 |
|
| 83 |
+
final_text = response.strip()
|
| 84 |
+
return final_text if final_text else "..."
|
| 85 |
|
| 86 |
except Exception as e:
|
| 87 |
+
print(f"CRITICAL ERROR: {e}")
|
| 88 |
return f"Toddler tantrum: {str(e)}"
|
| 89 |
|
| 90 |
# 5. Launch UI (Cleaned for Gradio 6.0 compatibility)
|