Aluode committed on
Commit
dff6149
·
verified ·
1 Parent(s): f76e339

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -2
app.py CHANGED
@@ -35,13 +35,21 @@ def generate_text(prompt, max_new_tokens=80, temperature=0.7):
35
 
36
  with torch.no_grad():
37
  for _ in range(max_new_tokens):
38
- # Pass the signal through the Moiré field
39
- logits, _ = model(input_ids)
 
 
 
 
 
 
40
  next_token_logits = logits[:, -1, :] / temperature
41
 
42
  # Sample the next token
43
  probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
44
  next_token = torch.multinomial(probs, num_samples=1)
 
 
45
  input_ids = torch.cat((input_ids, next_token), dim=1)
46
 
47
  # Stop if the field decides the thought is complete
 
35
 
36
  with torch.no_grad():
37
  for _ in range(max_new_tokens):
38
+ # CROP THE INPUT: Only look at the most recent max_seq_len tokens
39
+ # so the positional embeddings never go out of bounds (257)
40
+ cond_input = input_ids[:, -config.max_seq_len:]
41
+
42
+ # Pass the cropped signal through the Moiré field
43
+ logits, _ = model(cond_input)
44
+
45
+ # Grab the prediction for the last token
46
  next_token_logits = logits[:, -1, :] / temperature
47
 
48
  # Sample the next token
49
  probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
50
  next_token = torch.multinomial(probs, num_samples=1)
51
+
52
+ # Append it to the running sequence
53
  input_ids = torch.cat((input_ids, next_token), dim=1)
54
 
55
  # Stop if the field decides the thought is complete