Really-Amazing committed on
Commit
c424ad1
·
verified ·
1 Parent(s): 72950d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -16
app.py CHANGED
@@ -44,7 +44,7 @@ print("Toddler is awake and ready!")
44
 
45
  def chat_fn(message, history):
46
  try:
47
- # Build Chat History (Handling standard Gradio list-of-lists format)
48
  tokens = [tokenizer.bos_token_id]
49
  for user_msg, assistant_msg in history:
50
  if user_msg:
@@ -52,14 +52,14 @@ def chat_fn(message, history):
52
  if assistant_msg:
53
  tokens.extend([tokenizer.assistant_start_id] + tokenizer.encode(assistant_msg) + [tokenizer.assistant_end_id])
54
 
55
- # Add current user prompt
56
  tokens.extend([tokenizer.user_start_id] + tokenizer.encode(message) + [tokenizer.user_end_id])
57
  tokens.append(tokenizer.assistant_start_id)
58
 
59
  input_ids = torch.tensor([tokens], dtype=torch.long)
60
 
61
- # 4. Generate
62
  with torch.no_grad():
 
63
  output_ids = model.generate(
64
  input_ids,
65
  max_tokens=512,
@@ -67,28 +67,24 @@ def chat_fn(message, history):
67
  top_k=40
68
  )
69
 
70
- # Handle output
71
  if isinstance(output_ids, torch.Tensor):
 
72
  new_tokens = output_ids[0][input_ids.shape[1]:]
73
  response = tokenizer.decode(new_tokens.tolist())
74
  else:
75
- # Generator logic
76
- response = ""
77
- for token in output_ids:
78
- decoded = tokenizer.decode([token])
79
- if "<|assistant_end|>" in decoded:
80
- break
81
- response += decoded
82
- yield response
83
 
84
- # Final cleanup
85
- for tag in ["<|assistant_end|>", "<|end|>", "<|user_start|>"]:
86
  response = response.split(tag)[0]
87
 
88
- return response.strip()
 
89
 
90
  except Exception as e:
91
- print(f"ERROR: {e}")
92
  return f"Toddler tantrum: {str(e)}"
93
 
94
  # 5. Launch UI (Cleaned for Gradio 6.0 compatibility)
 
44
 
45
  def chat_fn(message, history):
46
  try:
47
+ # 1. Build Token List
48
  tokens = [tokenizer.bos_token_id]
49
  for user_msg, assistant_msg in history:
50
  if user_msg:
 
52
  if assistant_msg:
53
  tokens.extend([tokenizer.assistant_start_id] + tokenizer.encode(assistant_msg) + [tokenizer.assistant_end_id])
54
 
 
55
  tokens.extend([tokenizer.user_start_id] + tokenizer.encode(message) + [tokenizer.user_end_id])
56
  tokens.append(tokenizer.assistant_start_id)
57
 
58
  input_ids = torch.tensor([tokens], dtype=torch.long)
59
 
60
+ # 2. Generate (Non-streaming for stability)
61
  with torch.no_grad():
62
+ # In nanochat, generate usually returns the full sequence tensor
63
  output_ids = model.generate(
64
  input_ids,
65
  max_tokens=512,
 
67
  top_k=40
68
  )
69
 
70
+ # 3. Process Output
71
  if isinstance(output_ids, torch.Tensor):
72
+ # Slicing to get only new tokens
73
  new_tokens = output_ids[0][input_ids.shape[1]:]
74
  response = tokenizer.decode(new_tokens.tolist())
75
  else:
76
+ # If it's a generator, collect it all into one string
77
+ response = "".join([tokenizer.decode([t]) for t in output_ids])
 
 
 
 
 
 
78
 
79
+ # 4. Clean up tags
80
+ for tag in ["<|assistant_end|>", "<|end|>", "<|user_start|>", "<|bos|>"]:
81
  response = response.split(tag)[0]
82
 
83
+ final_text = response.strip()
84
+ return final_text if final_text else "..."
85
 
86
  except Exception as e:
87
+ print(f"CRITICAL ERROR: {e}")
88
  return f"Toddler tantrum: {str(e)}"
89
 
90
  # 5. Launch UI (Cleaned for Gradio 6.0 compatibility)