File size: 5,610 Bytes
53663e6
 
 
 
ac1d593
53663e6
 
ac1d593
53663e6
 
 
 
 
 
 
 
 
 
 
 
 
ac1d593
53663e6
 
 
 
 
 
 
 
 
 
 
ac1d593
53663e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac1d593
53663e6
 
 
 
 
 
 
 
 
ac1d593
53663e6
 
 
ac1d593
53663e6
 
ac1d593
53663e6
 
ac1d593
53663e6
 
 
 
 
 
 
 
ac1d593
53663e6
 
 
 
 
 
 
 
ac1d593
53663e6
 
 
 
 
 
 
 
ac1d593
53663e6
 
 
 
 
 
 
ac1d593
53663e6
 
 
 
 
 
 
ac1d593
 
53663e6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
import torch.nn as nn
import unicodedata
import os
import gradio as gr
from transformers import PreTrainedTokenizerFast, PretrainedConfig, PreTrainedModel
from tokenizers import decoders

# 1. Re-define the Architecture Classes (identical to the training/test phase)
class IsaiConfig(PretrainedConfig):
    """Configuration for the Isai causal language model.

    Llama-style hyperparameter set. Must match the architecture used at
    training time so that checkpoint weights load correctly.
    """
    model_type = "isai"

    def __init__(self, vocab_size=32000, hidden_size=1024, intermediate_size=2816, num_hidden_layers=24, num_attention_heads=16, num_key_value_heads=16, hidden_act="silu", max_position_embeddings=2048, initializer_range=0.02, rms_norm_eps=1e-6, use_cache=True, pad_token_id=0, bos_token_id=1, eos_token_id=2, **kwargs):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        # BUG FIX: these three parameters were accepted but never stored, so
        # they were silently dropped on every save/from_pretrained round-trip.
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.use_cache = use_cache

class IsaiRMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean-centering, no bias).

    Scales each hidden vector by the reciprocal of its RMS, then applies a
    learnable per-channel gain.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Gain initialized to identity; eps guards against division by zero.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Compute the normalization in float32 for numerical stability,
        # then cast back to the caller's dtype before applying the gain.
        original_dtype = hidden_states.dtype
        x = hidden_states.to(torch.float32)
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        normalized = x * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

class IsaiForCausalLM(PreTrainedModel):
    """Minimal causal LM matching the checkpoint's parameter layout.

    NOTE: the nested ModuleDict/ModuleList keys below ("model.embed_tokens",
    "model.layers.<i>.self_attn", ...) define the state_dict key names and
    must stay byte-identical to the names used when the checkpoint was saved,
    or load_state_dict will fail. Do not rename them.

    NOTE(review): "self_attn" here is a single Linear, not real attention —
    presumably a deliberate simplification mirroring the training code;
    confirm against the training script before extending.
    """
    config_class = IsaiConfig
    def __init__(self, config):
        super().__init__(config)
        self.model = nn.ModuleDict({
            "embed_tokens": nn.Embedding(config.vocab_size, config.hidden_size),
            "layers": nn.ModuleList([nn.ModuleDict({
                "input_layernorm": IsaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps),
                "post_attention_layernorm": IsaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps),
                # Stand-in for the attention block: one square projection.
                "self_attn": nn.Linear(config.hidden_size, config.hidden_size, bias=False),
                # Llama-style gated MLP: down(silu(gate(x)) * up(x)).
                "mlp": nn.ModuleDict({
                    "gate_proj": nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
                    "up_proj": nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
                    "down_proj": nn.Linear(config.intermediate_size, config.hidden_size, bias=False),
                })
            }) for _ in range(config.num_hidden_layers)]),
            "norm": IsaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        })
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # HF hook: runs weight init / final setup over the module tree.
        self.post_init()

    def forward(self, input_ids=None, **kwargs):
        """Return raw logits of shape (batch, seq_len, vocab_size).

        Pre-norm residual blocks: normalize, transform, add back.
        Extra kwargs (e.g. attention_mask) are accepted but ignored.
        """
        hidden_states = self.model.embed_tokens(input_ids)
        for layer in self.model.layers:
            # Residual branch 1: "attention" (linear) on the normalized input.
            h = layer.input_layernorm(hidden_states)
            hidden_states = hidden_states + layer.self_attn(h)
            # Residual branch 2: gated MLP on the normalized input.
            h = layer.post_attention_layernorm(hidden_states)
            hidden_states = hidden_states + layer.mlp.down_proj(nn.functional.silu(layer.mlp.gate_proj(h)) * layer.mlp.up_proj(h))
        # Final norm, then project to vocabulary. Returns a bare tensor,
        # not a ModelOutput — callers index it directly.
        logits = self.lm_head(self.model.norm(hidden_states))
        return logits

# 2. Load Model and Tokenizer
model_dir = "models/isai-v4.2"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = PreTrainedTokenizerFast.from_pretrained(model_dir)
# Re-attach a ByteLevel decoder so byte-level token pieces are merged back
# into valid UTF-8 text (critical for jaso restoration). NOTE(review): this
# reaches into the private _tokenizer attribute — presumably because the
# saved tokenizer config lacks a decoder; verify against the training export.
tokenizer._tokenizer.decoder = decoders.ByteLevel() # Critical for jaso restoration

config = IsaiConfig.from_pretrained(model_dir)
model = IsaiForCausalLM(config).to(device)

# Prioritize safetensors
# Prefer model.safetensors; fall back to pytorch_model.bin otherwise.
weights_path = os.path.join(model_dir, "model.safetensors")
if os.path.exists(weights_path):
    from safetensors.torch import load_file
    model.load_state_dict(load_file(weights_path))
else:
    # NOTE(review): torch.load on a pickle checkpoint executes arbitrary code
    # if the file is untrusted; consider weights_only=True here.
    model.load_state_dict(torch.load(os.path.join(model_dir, "pytorch_model.bin"), map_location=device))
model.eval()

# 3. Define the Prediction Logic with Jaso Processing
def predict(message, history):
    """Generate a chat reply via greedy decoding with jaso-level processing.

    Args:
        message: User input string (composed Hangul or any text).
        history: Prior chat turns supplied by gr.ChatInterface; unused —
            each reply conditions only on the current message.

    Returns:
        The generated reply, recomposed to NFC so jamo render as syllables.
    """
    # A. NFD decomposition: split Hangul syllables into jamo, matching the
    # jaso-level tokenization the model was trained on.
    decomposed_input = unicodedata.normalize('NFD', message)
    input_ids = tokenizer.encode(decomposed_input, return_tensors="pt").to(device)

    current_ids = input_ids
    max_new_tokens = 50

    # B. Greedy autoregressive generation. The no-grad scope is hoisted
    # around the whole loop (the original re-entered it per step), so no
    # autograd state is tracked anywhere during decoding. Note: the model
    # has no KV cache, so each step re-runs the full sequence.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(current_ids)
            next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0)
            current_ids = torch.cat([current_ids, next_token], dim=-1)
            if next_token.item() == tokenizer.eos_token_id:
                break

    # C. Decode only the newly generated suffix, then NFC-recompose so the
    # jamo stream displays as normal Hangul syllables.
    generated_tokens = current_ids[0][input_ids.shape[1]:]
    raw_response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    final_response = unicodedata.normalize('NFC', raw_response)

    return final_response

# 4. Create and Launch Gradio Interface
# ChatInterface wires predict(message, history) into a chat UI; the examples
# are clickable starter prompts.
demo = gr.ChatInterface(
    fn=predict,
    title="isai-v4.2 Jaso-Level Chat",
    description="μžμ†Œ λ‹¨μœ„(NFD)둜 μ†Œν†΅ν•˜λŠ” μ΄ˆμ†Œν˜• 일상 λŒ€ν™” λͺ¨λΈμž…λ‹ˆλ‹€. μž…λ ₯은 μžλ™μœΌλ‘œ λΆ„ν•΄λ˜κ³  좜λ ₯은 λ‹€μ‹œ ν•œκΈ€λ‘œ μ‘°ν•©λ©λ‹ˆλ‹€.",
    examples=["μ•ˆλ…•? λ°˜κ°€μ›Œ.", "였늘 날씨가 μ–΄λ•Œ?", "λ„ˆμ˜ 이름은 뭐야?"]
)

if __name__ == "__main__":
    # share=True opens a public Gradio tunnel in addition to localhost.
    demo.launch(share=True)