File size: 966 Bytes
54c5666
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
from src.models.ultrathink import UltraThinkModel, UltraThinkConfig
from src.models.architecture import ModelConfig


def tiny_model():
    """Return a very small ``UltraThinkModel`` suitable for smoke tests.

    The wrapped ``ModelConfig`` uses:

    * a 256-token vocabulary and 64 positions;
    * 64-dim embeddings, 2 layers, and 4 heads / 4 KV heads;
    * an intermediate size of 128 with ``relu`` activation.

    Dropout and attention dropout are 0.0. Flash attention and gradient
    checkpointing are off.
    """
    backbone_cfg = ModelConfig(
        vocab_size=256,
        n_positions=64,
        n_embd=64,
        n_layer=2,
        n_head=4,
        n_kv_head=4,
        intermediate_size=128,
        activation="relu",
        dropout=0.0,
        attention_dropout=0.0,
        flash_attention=False,
        gradient_checkpointing=False,
    )
    wrapper_cfg = UltraThinkConfig(model_config=backbone_cfg)
    return UltraThinkModel(wrapper_cfg)


def test_forward_smoke():
    """Run one no-grad forward pass with labels and check the loss is finite.

    Token ids are drawn from a locally seeded generator, so a failure
    reproduces on rerun. The upper bound 256 must match the ``vocab_size``
    passed to ``ModelConfig`` in ``tiny_model``.
    """
    model = tiny_model()
    model.eval()
    # Private seeded generator: deterministic inputs without mutating the
    # global RNG state that other tests may depend on.
    gen = torch.Generator().manual_seed(0)
    input_ids = torch.randint(0, 256, (2, 16), generator=gen)
    attn = torch.ones_like(input_ids)
    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attn, labels=input_ids)
    assert "loss" in out, "model output has no 'loss' entry"
    loss = out["loss"]
    # Reduce with .all() so the check works whether or not the loss is a
    # 0-d tensor. A bare bool() on a multi-element tensor raises.
    assert torch.isfinite(loss).all(), f"non-finite loss: {loss}"