import os
import unicodedata

import gradio as gr
import torch
import torch.nn as nn
from tokenizers import decoders
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizerFast


# 1. Re-define the Architecture Classes (identical to the training/test phase)
class IsaiConfig(PretrainedConfig):
    """Hyperparameter container for the Isai causal LM (LLaMA-style names)."""

    model_type = "isai"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=1024,
        intermediate_size=2816,
        num_hidden_layers=24,
        num_attention_heads=16,
        num_key_value_heads=16,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        # BUGFIX: the original accepted these three arguments but never stored
        # them, so they were silently dropped and did not round-trip through
        # save_pretrained()/from_pretrained().
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.use_cache = use_cache


class IsaiRMSNorm(nn.Module):
    """Root-mean-square LayerNorm (no mean subtraction), computed in fp32."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Normalize in float32 for numerical stability, then cast back to the
        # caller's dtype (e.g. fp16/bf16).
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class IsaiForCausalLM(PreTrainedModel):
    """Minimal causal LM whose parameter tree mirrors the training checkpoint.

    NOTE(review): ``self_attn`` is a plain per-token ``nn.Linear`` — there is
    no token mixing / attention here, and no positional information is added.
    This matches the checkpoint layout; presumably intentional for this toy
    model — confirm against the training code.
    """

    config_class = IsaiConfig

    def __init__(self, config):
        super().__init__(config)
        # nn.ModuleDict/ModuleList keys are chosen so state_dict names line up
        # with the saved checkpoint (model.embed_tokens, model.layers.N.*, ...).
        self.model = nn.ModuleDict({
            "embed_tokens": nn.Embedding(config.vocab_size, config.hidden_size),
            "layers": nn.ModuleList([
                nn.ModuleDict({
                    "input_layernorm": IsaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps),
                    "post_attention_layernorm": IsaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps),
                    "self_attn": nn.Linear(config.hidden_size, config.hidden_size, bias=False),
                    "mlp": nn.ModuleDict({
                        "gate_proj": nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
                        "up_proj": nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
                        "down_proj": nn.Linear(config.intermediate_size, config.hidden_size, bias=False),
                    }),
                })
                for _ in range(config.num_hidden_layers)
            ]),
            "norm": IsaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps),
        })
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def forward(self, input_ids=None, **kwargs):
        """Return raw logits of shape (batch, seq_len, vocab_size).

        Pre-norm residual blocks: x + attn(norm(x)), then x + SwiGLU-MLP(norm(x)).
        """
        hidden_states = self.model.embed_tokens(input_ids)
        for layer in self.model.layers:
            h = layer.input_layernorm(hidden_states)
            hidden_states = hidden_states + layer.self_attn(h)
            h = layer.post_attention_layernorm(hidden_states)
            hidden_states = hidden_states + layer.mlp.down_proj(
                nn.functional.silu(layer.mlp.gate_proj(h)) * layer.mlp.up_proj(h)
            )
        logits = self.lm_head(self.model.norm(hidden_states))
        return logits


# 2. Load Model and Tokenizer
model_dir = "models/isai-v4.2"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = PreTrainedTokenizerFast.from_pretrained(model_dir)
tokenizer._tokenizer.decoder = decoders.ByteLevel()  # Critical for jaso restoration

config = IsaiConfig.from_pretrained(model_dir)
model = IsaiForCausalLM(config).to(device)

# Prioritize safetensors over the legacy pickle-based checkpoint.
weights_path = os.path.join(model_dir, "model.safetensors")
if os.path.exists(weights_path):
    from safetensors.torch import load_file
    model.load_state_dict(load_file(weights_path))
else:
    model.load_state_dict(
        torch.load(os.path.join(model_dir, "pytorch_model.bin"), map_location=device)
    )
model.eval()


# 3. Define the Prediction Logic with Jaso Processing
def predict(message, history):
    """Greedy-decode a reply for a Gradio ChatInterface callback.

    Args:
        message: the user's latest utterance (str).
        history: prior chat turns supplied by Gradio; unused — the model is
            conditioned on the current message only.

    Returns:
        The generated reply as an NFC-composed string.
    """
    # A. NFD Decomposition (Input): the model operates on decomposed jaso.
    decomposed_input = unicodedata.normalize('NFD', message)
    input_ids = tokenizer.encode(decomposed_input, return_tensors="pt").to(device)

    current_ids = input_ids
    max_new_tokens = 50

    # B. Generate tokens greedily, one at a time (no KV cache in this model,
    # so the full prefix is re-run each step).
    with torch.no_grad():  # hoisted out of the loop — one context for all steps
        for _ in range(max_new_tokens):
            logits = model(current_ids)
            next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0)
            current_ids = torch.cat([current_ids, next_token], dim=-1)
            if next_token.item() == tokenizer.eos_token_id:
                break

    # C. Decode and NFC Recomposition (Output)
    # Only decode the generated part, not the echoed prompt.
    generated_tokens = current_ids[0][input_ids.shape[1]:]
    raw_response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    final_response = unicodedata.normalize('NFC', raw_response)
    return final_response


# 4. Create and Launch Gradio Interface
demo = gr.ChatInterface(
    fn=predict,
    title="isai-v4.2 Jaso-Level Chat",
    description="자소 단위(NFD)로 소통하는 초소형 일상 대화 모델입니다. 입력은 자동으로 분해되고 출력은 다시 한글로 조합됩니다.",
    examples=["안녕? 반가워.", "오늘 날씨가 어때?", "너의 이름은 뭐야?"],
)

if __name__ == "__main__":
    demo.launch(share=True)