amirali1985 committed on
Commit
11c4cd3
·
verified ·
1 Parent(s): a6e6b30

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +27 -16
app.py CHANGED
@@ -408,44 +408,55 @@ so shuffling disrupts every digit's computation.
408
  gr.Markdown("""### Using the models
409
 
410
  All models are on [HuggingFace](https://huggingface.co/thoughtworks/arithmetic-sorl).
411
- To load a model and run inference:
412
 
413
  ```python
 
414
  from arithmetic.hub import load_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
  from arithmetic.train import QWEN3_TOKEN_MAP, QWEN3_INV_MAP
416
  from sorl.sorl_trainer import infer_insert_mask, insert_tokens_with_padding, expand_prompt_len
417
 
418
- # Load model
419
- model, config, metrics = load_model("add_sub_sorl_v1_abs30_K1_100K", device="cuda")
420
  base_v = model.vocab_sizes[0].item()
421
 
422
  # Encode: 123456+654321=
423
- tokens = [1,2,3,4,5,6, 10, 6,5,4,3,2,1, 12] # internal token IDs
424
- qwen_ids = torch.tensor([QWEN3_TOKEN_MAP[t] for t in tokens], device="cuda")
425
 
426
- # Insert abstraction tokens (K=1 = every position)
427
- seq = qwen_ids.unsqueeze(0)
428
- im = infer_insert_mask(seq, K=1, attention_mask=torch.ones_like(seq))
 
429
  ep = expand_prompt_len(torch.tensor([14], device="cuda"), im)
430
- ed, ea = insert_tokens_with_padding(seq, torch.ones_like(seq), im, model.vocab_sizes[0], 151643)
431
 
432
- # Recursion fills abstraction tokens
433
  data, ppt, logits = model.recursion(ed, ea, max_iterations=2,
434
  memory_span_abs=1792, memory_span_traj=1792, temperature=0.0, prompt_len=ep)
435
 
436
  # Separate trajectory vs abstraction tokens
437
  is_abs = data[0] >= base_v
438
- trajectory = data[0][~is_abs] # real digit tokens
439
  abstractions = data[0][is_abs] - base_v # abstraction token IDs (0-indexed)
440
-
441
- # Decode answer
442
- answer = [QWEN3_INV_MAP[t.item()] for t in trajectory[14:]] # skip prompt
443
- print(f"Answer: {''.join(str(d) for d in answer)}")
444
  print(f"Abstraction tokens: {abstractions.tolist()}")
 
445
  ```
446
 
447
  Token IDs: `0-9` = digits, `10` = `+`, `11` = `-`, `12` = `=`.
448
- Abstraction tokens are integers from 0 to `abs_vocab-1`, where 0 is the placeholder.
449
  """)
450
 
451
 
 
408
  gr.Markdown("""### Using the models
409
 
410
  All models are on [HuggingFace](https://huggingface.co/thoughtworks/arithmetic-sorl).
411
+ Code is on the [`amir/arithmetic`](https://github.com/fangyuan-ksgk/mod_gpt/tree/amir/arithmetic) branch.
412
 
413
  ```python
414
+ import torch
415
  from arithmetic.hub import load_model
416
+ from arithmetic.evaluate import ArithmeticEvaluator
417
+ from transformers import AutoTokenizer
418
+
419
+ # Load model + tokenizer
420
+ model, config, metrics = load_model("add_sub_sorl_v1_abs30_K1_100K", device="cuda")
421
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
422
+
423
+ # Run full evaluation with per-split accuracy
424
+ evaluator = ArithmeticEvaluator(model, tokenizer, device="cuda")
425
+ results = evaluator.run(ops="add_sub", K=1, n_per_split=100) # K=None for baseline
426
+ evaluator.print_table(results)
427
+ ```
428
+
429
+ To inspect abstraction tokens on a single example (this reuses `torch` and `model` from the snippet above):
430
+
431
+ ```python
432
  from arithmetic.train import QWEN3_TOKEN_MAP, QWEN3_INV_MAP
433
  from sorl.sorl_trainer import infer_insert_mask, insert_tokens_with_padding, expand_prompt_len
434
 
 
 
435
  base_v = model.vocab_sizes[0].item()
436
 
437
  # Encode: 123456+654321=
438
+ prompt = [1,2,3,4,5,6, 10, 6,5,4,3,2,1, 12]
439
+ qwen_ids = torch.tensor([QWEN3_TOKEN_MAP[t] for t in prompt], device="cuda")
440
 
441
+ # Pad to full 21 tokens (14 prompt + 7 dummy answer), insert abstractions, recurse
442
+ seq = torch.cat([qwen_ids, torch.zeros(7, dtype=torch.long, device="cuda")])
443
+ ids = seq.unsqueeze(0)
444
+ im = infer_insert_mask(ids, K=1, attention_mask=torch.ones_like(ids))
445
  ep = expand_prompt_len(torch.tensor([14], device="cuda"), im)
446
+ ed, ea = insert_tokens_with_padding(ids, torch.ones_like(ids), im, model.vocab_sizes[0], 151643)
447
 
 
448
  data, ppt, logits = model.recursion(ed, ea, max_iterations=2,
449
  memory_span_abs=1792, memory_span_traj=1792, temperature=0.0, prompt_len=ep)
450
 
451
  # Separate trajectory vs abstraction tokens
452
  is_abs = data[0] >= base_v
 
453
  abstractions = data[0][is_abs] - base_v # abstraction token IDs (0-indexed)
 
 
 
 
454
  print(f"Abstraction tokens: {abstractions.tolist()}")
455
+ # Each abstraction token encodes carry/borrow state at that position
456
  ```
457
 
458
  Token IDs: `0-9` = digits, `10` = `+`, `11` = `-`, `12` = `=`.
459
+ Abstraction tokens are integers from 1 to `abs_vocab` (0 is the placeholder before recursion).
460
  """)
461
 
462