File size: 4,312 Bytes
f05bc67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
Test the TinyV4 base model
— load from the HuggingFace Hub (ukung/tinyv4) or from a local folder
— forward pass, text generation (ID & EN)
"""
import torch
import json
import os
import sys

# ---------------------------------------------------------------------------
# 0. Config -- set USE_HUB to False to test from the local folder instead
# ---------------------------------------------------------------------------
USE_HUB = True
HF_REPO = "ukung/tinyv4"

# ---------------------------------------------------------------------------
# 1. Load tokenizer and model
# ---------------------------------------------------------------------------
if USE_HUB:
    # Hub path: tokenizer and model both come from the HF_REPO repository.
    # trust_remote_code=True is passed for the model only, so transformers may
    # run the modeling code published in that repository.
    from transformers import AutoTokenizer, AutoModel
    tokenizer = AutoTokenizer.from_pretrained(HF_REPO)
    model = AutoModel.from_pretrained(HF_REPO, trust_remote_code=True)
else:
    # Local path: the directory containing this script is treated as the model
    # folder and is put first on sys.path so modeling_tinyv4 can be imported.
    MODEL_DIR = os.path.dirname(os.path.abspath(__file__))
    sys.path.insert(0, MODEL_DIR)
    from modeling_tinyv4 import TinyV4, TinyV4Config
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    model = TinyV4.from_pretrained(MODEL_DIR)

# Post-load steps shared by both paths: make the output head reuse the
# embedding weight tensor (tied embeddings), then switch to eval mode.
model.head.weight = model.embed.weight
model.eval()

# Report the model size as the summed element count of model.parameters().
n_params = sum(p.numel() for p in model.parameters())
print(f"✅ Model loaded: {n_params:,} params ({n_params/1e6:.2f}M)")

# ---------------------------------------------------------------------------
# 2. Config info
# ---------------------------------------------------------------------------
# cfg is also read later by generate() (cfg.max_len), so the name must stay.
cfg = model.config
print(f"✅ Config: dim={cfg.dim}, depth={cfg.depth}, vocab={cfg.vocab_size}")
print(f"   MoE: {cfg.n_routed} routed + {cfg.n_shared} shared, {cfg.n_active} active")
print(f"   MTP: depth={cfg.mtp_depth}, max_len={cfg.max_len}")

# ---------------------------------------------------------------------------
# 3. Tie check
# ---------------------------------------------------------------------------
# The head and embedding weights must share the same storage. Raise explicitly
# (same exception type as before) so the check still runs under `python -O`,
# where plain `assert` statements are stripped.
if model.head.weight.data_ptr() != model.embed.weight.data_ptr():
    raise AssertionError("❌ Embedding tie FAILED!")
print("✅ Embedding tie: OK")

# ---------------------------------------------------------------------------
# 4. Forward pass (smoke test)
# ---------------------------------------------------------------------------
# Random token ids: a batch of 2 sequences with 64 tokens each.
dummy = torch.randint(0, cfg.vocab_size, (2, 64))
with torch.no_grad():
    # The model call is unpacked as (logits, MTP logits, balance loss);
    # the last two may be None and are handled below.
    logits, mtp, bal = model(dummy)

has_nan = torch.isnan(logits).any().item()
has_inf = torch.isinf(logits).any().item()
# Only show the success mark when the logits are actually finite; the script
# keeps running either way so the remaining diagnostics are still printed.
status = "❌" if (has_nan or has_inf) else "✅"
print(f"{status} Forward pass: logits={logits.shape}, NaN={has_nan}, Inf={has_inf}")
if mtp is not None:
    print(f"   MTP logits: {mtp.shape}, NaN={torch.isnan(mtp).any().item()}")
if bal is not None:
    print(f"   Balance loss: {bal.item():.6f}")
else:
    print("   Balance loss: None")

# ---------------------------------------------------------------------------
# 5. Generate text
# ---------------------------------------------------------------------------
@torch.no_grad()
def generate(prompt, max_new_tokens=60, temperature=0.8, top_k=40):
    """Extend *prompt* token by token and return the decoded text.

    Uses the module-level ``tokenizer``, ``model`` and ``cfg``. The prompt is
    encoded as a single sequence, so generation runs with one sequence only.

    Args:
        prompt: Text to start from.
        max_new_tokens: Upper bound on the number of tokens appended.
        temperature: Divisor applied to the last-position logits before
            sampling. ``None`` or a value <= 0 selects greedy decoding
            (argmax) instead of dividing by zero.
        top_k: Keep only the ``top_k`` highest logits before sampling.
            ``None`` or a value <= 0 disables the filter.

    Returns:
        The prompt plus generated tokens, decoded with special tokens skipped.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    # A non-positive temperature would turn the division below into inf/NaN
    # and silently trigger the uniform fallback, so treat it as greedy.
    greedy = temperature is None or temperature <= 0
    use_top_k = top_k is not None and top_k > 0

    for _ in range(max_new_tokens):
        # Feed at most the trailing cfg.max_len tokens to the model.
        context = input_ids[:, -cfg.max_len:]
        logits, _, _ = model(context)
        logits = logits[:, -1, :]  # logits for the next token only

        if greedy:
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            logits = logits / temperature

            if use_top_k:
                # Mask everything below the k-th largest logit; ties with
                # the k-th value are kept.
                top_values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < top_values[:, [-1]]] = float("-inf")

            probs = torch.softmax(logits, dim=-1)

            # If the distribution is unusable (NaN/Inf), fall back to a
            # uniform one so multinomial sampling does not fail.
            if torch.isnan(probs).any() or torch.isinf(probs).any():
                probs = torch.ones_like(probs) / probs.size(-1)

            next_token = torch.multinomial(probs, 1)

        input_ids = torch.cat([input_ids, next_token], dim=1)

        # Stop once the end-of-sequence token has been produced (it is already
        # appended; decode() below drops special tokens).
        if next_token.item() == tokenizer.eos_token_id:
            break

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

# Banner for the generation section.
print()
print("=" * 60)
print("📝 GENERATION TEST")
print("=" * 60)

# (language tag, prompt) pairs: three English and three Indonesian openers.
prompts = [
    ("EN", "Once upon a time,"),
    ("EN", "There was a little"),
    ("EN", "In a small village,"),
    ("ID", "Pada suatu hari,"),
    ("ID", "Di sebuah desa kecil,"),
    ("ID", "Alkisah, tersebutlah"),
]

# Generate up to 50 new tokens per prompt and print prompt and output together.
for lang, prompt in prompts:
    output = generate(prompt, max_new_tokens=50, temperature=0.8)
    print(f"  [{lang}] {prompt}")
    print(f"  → {output}")
    print()

# Reaching this point means no earlier step raised.
print("=" * 60)
print("✅ ALL TESTS PASSED")
print("=" * 60)