Upload folder using huggingface_hub
Browse files
app.py
CHANGED
|
@@ -408,44 +408,55 @@ so shuffling disrupts every digit's computation.
|
|
| 408 |
gr.Markdown("""### Using the models
|
| 409 |
|
| 410 |
All models are on [HuggingFace](https://huggingface.co/thoughtworks/arithmetic-sorl).
|
| 411 |
-
|
| 412 |
|
| 413 |
```python
|
|
|
|
| 414 |
from arithmetic.hub import load_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 415 |
from arithmetic.train import QWEN3_TOKEN_MAP, QWEN3_INV_MAP
|
| 416 |
from sorl.sorl_trainer import infer_insert_mask, insert_tokens_with_padding, expand_prompt_len
|
| 417 |
|
| 418 |
-
# Load model
|
| 419 |
-
model, config, metrics = load_model("add_sub_sorl_v1_abs30_K1_100K", device="cuda")
|
| 420 |
base_v = model.vocab_sizes[0].item()
|
| 421 |
|
| 422 |
# Encode: 123456+654321=
|
| 423 |
-
|
| 424 |
-
qwen_ids = torch.tensor([QWEN3_TOKEN_MAP[t] for t in
|
| 425 |
|
| 426 |
-
#
|
| 427 |
-
seq = qwen_ids.
|
| 428 |
-
|
|
|
|
| 429 |
ep = expand_prompt_len(torch.tensor([14], device="cuda"), im)
|
| 430 |
-
ed, ea = insert_tokens_with_padding(
|
| 431 |
|
| 432 |
-
# Recursion fills abstraction tokens
|
| 433 |
data, ppt, logits = model.recursion(ed, ea, max_iterations=2,
|
| 434 |
memory_span_abs=1792, memory_span_traj=1792, temperature=0.0, prompt_len=ep)
|
| 435 |
|
| 436 |
# Separate trajectory vs abstraction tokens
|
| 437 |
is_abs = data[0] >= base_v
|
| 438 |
-
trajectory = data[0][~is_abs] # real digit tokens
|
| 439 |
abstractions = data[0][is_abs] - base_v # abstraction token IDs (0-indexed)
|
| 440 |
-
|
| 441 |
-
# Decode answer
|
| 442 |
-
answer = [QWEN3_INV_MAP[t.item()] for t in trajectory[14:]] # skip prompt
|
| 443 |
-
print(f"Answer: {''.join(str(d) for d in answer)}")
|
| 444 |
print(f"Abstraction tokens: {abstractions.tolist()}")
|
|
|
|
| 445 |
```
|
| 446 |
|
| 447 |
Token IDs: `0-9` = digits, `10` = `+`, `11` = `-`, `12` = `=`.
|
| 448 |
-
Abstraction tokens are integers from
|
| 449 |
""")
|
| 450 |
|
| 451 |
|
|
|
|
| 408 |
gr.Markdown("""### Using the models
|
| 409 |
|
| 410 |
All models are on [HuggingFace](https://huggingface.co/thoughtworks/arithmetic-sorl).
|
| 411 |
+
Code is on the [`amir/arithmetic`](https://github.com/fangyuan-ksgk/mod_gpt/tree/amir/arithmetic) branch.
|
| 412 |
|
| 413 |
```python
|
| 414 |
+
import torch
|
| 415 |
from arithmetic.hub import load_model
|
| 416 |
+
from arithmetic.evaluate import ArithmeticEvaluator
|
| 417 |
+
from transformers import AutoTokenizer
|
| 418 |
+
|
| 419 |
+
# Load model + tokenizer
|
| 420 |
+
model, config, metrics = load_model("add_sub_sorl_v1_abs30_K1_100K", device="cuda")
|
| 421 |
+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
| 422 |
+
|
| 423 |
+
# Run full evaluation with per-split accuracy
|
| 424 |
+
evaluator = ArithmeticEvaluator(model, tokenizer, device="cuda")
|
| 425 |
+
results = evaluator.run(ops="add_sub", K=1, n_per_split=100) # K=None for baseline
|
| 426 |
+
evaluator.print_table(results)
|
| 427 |
+
```
|
| 428 |
+
|
| 429 |
+
To inspect abstraction tokens on a single example:
|
| 430 |
+
|
| 431 |
+
```python
|
| 432 |
from arithmetic.train import QWEN3_TOKEN_MAP, QWEN3_INV_MAP
|
| 433 |
from sorl.sorl_trainer import infer_insert_mask, insert_tokens_with_padding, expand_prompt_len
|
| 434 |
|
|
|
|
|
|
|
| 435 |
base_v = model.vocab_sizes[0].item()
|
| 436 |
|
| 437 |
# Encode: 123456+654321=
|
| 438 |
+
prompt = [1,2,3,4,5,6, 10, 6,5,4,3,2,1, 12]
|
| 439 |
+
qwen_ids = torch.tensor([QWEN3_TOKEN_MAP[t] for t in prompt], device="cuda")
|
| 440 |
|
| 441 |
+
# Pad to full 21 tokens (14 prompt + 7 dummy answer), insert abstractions, recurse
|
| 442 |
+
seq = torch.cat([qwen_ids, torch.zeros(7, dtype=torch.long, device="cuda")])
|
| 443 |
+
ids = seq.unsqueeze(0)
|
| 444 |
+
im = infer_insert_mask(ids, K=1, attention_mask=torch.ones_like(ids))
|
| 445 |
ep = expand_prompt_len(torch.tensor([14], device="cuda"), im)
|
| 446 |
+
ed, ea = insert_tokens_with_padding(ids, torch.ones_like(ids), im, model.vocab_sizes[0], 151643)
|
| 447 |
|
|
|
|
| 448 |
data, ppt, logits = model.recursion(ed, ea, max_iterations=2,
|
| 449 |
memory_span_abs=1792, memory_span_traj=1792, temperature=0.0, prompt_len=ep)
|
| 450 |
|
| 451 |
# Separate trajectory vs abstraction tokens
|
| 452 |
is_abs = data[0] >= base_v
|
|
|
|
| 453 |
abstractions = data[0][is_abs] - base_v # abstraction token IDs as offsets from base_v (0 = placeholder)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
print(f"Abstraction tokens: {abstractions.tolist()}")
|
| 455 |
+
# Each abstraction token encodes carry/borrow state at that position
|
| 456 |
```
|
| 457 |
|
| 458 |
Token IDs: `0-9` = digits, `10` = `+`, `11` = `-`, `12` = `=`.
|
| 459 |
+
After subtracting `base_v`, abstraction tokens are integers from 1 to `abs_vocab`; 0 is the placeholder inserted before recursion fills it in.
|
| 460 |
""")
|
| 461 |
|
| 462 |
|