Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -35,13 +35,21 @@ def generate_text(prompt, max_new_tokens=80, temperature=0.7):
|
|
| 35 |
|
| 36 |
with torch.no_grad():
|
| 37 |
for _ in range(max_new_tokens):
|
| 38 |
-
#
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
next_token_logits = logits[:, -1, :] / temperature
|
| 41 |
|
| 42 |
# Sample the next token
|
| 43 |
probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
|
| 44 |
next_token = torch.multinomial(probs, num_samples=1)
|
|
|
|
|
|
|
| 45 |
input_ids = torch.cat((input_ids, next_token), dim=1)
|
| 46 |
|
| 47 |
# Stop if the field decides the thought is complete
|
|
|
|
| 35 |
|
| 36 |
with torch.no_grad():
|
| 37 |
for _ in range(max_new_tokens):
|
| 38 |
+
# CROP THE INPUT: Only look at the most recent max_seq_len tokens
|
| 39 |
+
# so the positional embeddings never go out of bounds (257)
|
| 40 |
+
cond_input = input_ids[:, -config.max_seq_len:]
|
| 41 |
+
|
| 42 |
+
# Pass the cropped signal through the Moiré field
|
| 43 |
+
logits, _ = model(cond_input)
|
| 44 |
+
|
| 45 |
+
# Grab the prediction for the last token
|
| 46 |
next_token_logits = logits[:, -1, :] / temperature
|
| 47 |
|
| 48 |
# Sample the next token
|
| 49 |
probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
|
| 50 |
next_token = torch.multinomial(probs, num_samples=1)
|
| 51 |
+
|
| 52 |
+
# Append it to the running sequence
|
| 53 |
input_ids = torch.cat((input_ids, next_token), dim=1)
|
| 54 |
|
| 55 |
# Stop if the field decides the thought is complete
|