ukung commited on
Commit
f05bc67
Β·
verified Β·
1 Parent(s): 9a517d7

Upload test.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. test.py +118 -0
test.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test TinyV4 base model
3
+ β€” load dari HuggingFace Hub (ukung/tinyv4) atau dari folder lokal
4
+ β€” forward pass, generate text (ID & EN)
5
+ """
6
+ import torch
7
+ import json
8
+ import os
9
+ import sys
10
+
11
+ # ---------------------------------------------------------------------------
12
+ # 0. Config β€” ganti ke False kalau mau test dari folder lokal
13
+ # ---------------------------------------------------------------------------
14
+ USE_HUB = True
15
+ HF_REPO = "ukung/tinyv4"
16
+
17
+ if USE_HUB:
18
+ # Load dari HuggingFace Hub (trust_remote_code=True)
19
+ from transformers import AutoTokenizer, AutoModel
20
+ tokenizer = AutoTokenizer.from_pretrained(HF_REPO)
21
+ model = AutoModel.from_pretrained(HF_REPO, trust_remote_code=True)
22
+ model.head.weight = model.embed.weight # tie embeddings
23
+ model.eval()
24
+ else:
25
+ # Load dari folder lokal
26
+ MODEL_DIR = os.path.dirname(os.path.abspath(__file__))
27
+ sys.path.insert(0, MODEL_DIR)
28
+ from modeling_tinyv4 import TinyV4, TinyV4Config
29
+ from transformers import AutoTokenizer
30
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
31
+ model = TinyV4.from_pretrained(MODEL_DIR)
32
+ model.head.weight = model.embed.weight # tie embeddings
33
+ model.eval()
34
+
35
+ n_params = sum(p.numel() for p in model.parameters())
36
+ print(f"βœ… Model loaded: {n_params:,} params ({n_params/1e6:.2f}M)")
37
+
38
+ # ---------------------------------------------------------------------------
39
+ # 2. Config info
40
+ # ---------------------------------------------------------------------------
41
+ cfg = model.config
42
+ print(f"βœ… Config: dim={cfg.dim}, depth={cfg.depth}, vocab={cfg.vocab_size}")
43
+ print(f" MoE: {cfg.n_routed} routed + {cfg.n_shared} shared, {cfg.n_active} active")
44
+ print(f" MTP: depth={cfg.mtp_depth}, max_len={cfg.max_len}")
45
+
46
+ # ---------------------------------------------------------------------------
47
+ # 3. Tie check
48
+ # ---------------------------------------------------------------------------
49
+ assert model.head.weight.data_ptr() == model.embed.weight.data_ptr(), "❌ Embedding tie FAILED!"
50
+ print("βœ… Embedding tie: OK")
51
+
52
+ # ---------------------------------------------------------------------------
53
+ # 4. Forward pass (smoke test)
54
+ # ---------------------------------------------------------------------------
55
+ dummy = torch.randint(0, cfg.vocab_size, (2, 64))
56
+ with torch.no_grad():
57
+ logits, mtp, bal = model(dummy)
58
+
59
+ has_nan = torch.isnan(logits).any().item()
60
+ has_inf = torch.isinf(logits).any().item()
61
+ print(f"βœ… Forward pass: logits={logits.shape}, NaN={has_nan}, Inf={has_inf}")
62
+ if mtp is not None:
63
+ print(f" MTP logits: {mtp.shape}, NaN={torch.isnan(mtp).any().item()}")
64
+ print(f" Balance loss: {bal.item():.6f}" if bal is not None else " Balance loss: None")
65
+
66
+ # ---------------------------------------------------------------------------
67
+ # 5. Generate text
68
+ # ---------------------------------------------------------------------------
69
+ @torch.no_grad()
70
+ def generate(prompt, max_new_tokens=60, temperature=0.8, top_k=40):
71
+ input_ids = tokenizer.encode(prompt, return_tensors="pt")
72
+
73
+ for _ in range(max_new_tokens):
74
+ idx = input_ids[:, -cfg.max_len:]
75
+ logits, _, _ = model(idx)
76
+ logits = logits[:, -1, :] / temperature
77
+
78
+ # Top-k filter
79
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
80
+ logits[logits < v[:, [-1]]] = float("-inf")
81
+
82
+ probs = torch.softmax(logits, dim=-1)
83
+
84
+ # Cek NaN di probs β€” fallback ke uniform
85
+ if torch.isnan(probs).any() or torch.isinf(probs).any():
86
+ probs = torch.ones_like(probs) / probs.size(-1)
87
+
88
+ next_token = torch.multinomial(probs, 1)
89
+ input_ids = torch.cat([input_ids, next_token], dim=1)
90
+
91
+ if next_token.item() == tokenizer.eos_token_id:
92
+ break
93
+
94
+ return tokenizer.decode(input_ids[0], skip_special_tokens=True)
95
+
96
+ print()
97
+ print("=" * 60)
98
+ print("πŸ“ GENERATION TEST")
99
+ print("=" * 60)
100
+
101
+ prompts = [
102
+ ("EN", "Once upon a time,"),
103
+ ("EN", "There was a little"),
104
+ ("EN", "In a small village,"),
105
+ ("ID", "Pada suatu hari,"),
106
+ ("ID", "Di sebuah desa kecil,"),
107
+ ("ID", "Alkisah, tersebutlah"),
108
+ ]
109
+
110
+ for lang, prompt in prompts:
111
+ output = generate(prompt, max_new_tokens=50, temperature=0.8)
112
+ print(f" [{lang}] {prompt}")
113
+ print(f" β†’ {output}")
114
+ print()
115
+
116
+ print("=" * 60)
117
+ print("βœ… ALL TESTS PASSED")
118
+ print("=" * 60)