import torch
import torch.nn as nn
import unicodedata
import os
import gradio as gr
from transformers import PreTrainedTokenizerFast, PretrainedConfig, PreTrainedModel
from tokenizers import decoders
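# Inference + Gradio chat demo for isai-v4.2, a jaso (Hangul jamo) level
# causal language model: input text is NFD-decomposed into jamo before
# tokenization, and generated jamo are NFC-recomposed into Hangul syllables.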
# 1. Re-define the Architecture Classes (identical to the training/test phase)
class IsaiConfig(PretrainedConfig):
model_type = "isai"
    def __init__(self, vocab_size=32000, hidden_size=1024, intermediate_size=2816, num_hidden_layers=24, num_attention_heads=16, num_key_value_heads=16, hidden_act="silu", max_position_embeddings=2048, initializer_range=0.02, rms_norm_eps=1e-6, use_cache=True, pad_token_id=0, bos_token_id=1, eos_token_id=2, **kwargs):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act  # store so these round-trip through config.json
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
class IsaiRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
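        # Compute the RMS statistic in float32 for numerical stability,
        # then cast back to the input dtype after normalizing.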
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class IsaiForCausalLM(PreTrainedModel):
config_class = IsaiConfig
def __init__(self, config):
super().__init__(config)
self.model = nn.ModuleDict({
"embed_tokens": nn.Embedding(config.vocab_size, config.hidden_size),
"layers": nn.ModuleList([nn.ModuleDict({
"input_layernorm": IsaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps),
"post_attention_layernorm": IsaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps),
"self_attn": nn.Linear(config.hidden_size, config.hidden_size, bias=False),
"mlp": nn.ModuleDict({
"gate_proj": nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
"up_proj": nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
"down_proj": nn.Linear(config.intermediate_size, config.hidden_size, bias=False),
})
}) for _ in range(config.num_hidden_layers)]),
"norm": IsaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
})
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
    def forward(self, input_ids=None, **kwargs):
        hidden_states = self.model.embed_tokens(input_ids)
        for layer in self.model.layers:
            # Pre-norm residual block 1: RMSNorm -> linear "self_attn" projection
            h = layer.input_layernorm(hidden_states)
            hidden_states = hidden_states + layer.self_attn(h)
            # Pre-norm residual block 2: RMSNorm -> SwiGLU-style MLP
            h = layer.post_attention_layernorm(hidden_states)
            hidden_states = hidden_states + layer.mlp.down_proj(nn.functional.silu(layer.mlp.gate_proj(h)) * layer.mlp.up_proj(h))
        logits = self.lm_head(self.model.norm(hidden_states))
        return logits
# 2. Load Model and Tokenizer
model_dir = "models/isai-v4.2"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_dir)
# Force byte-level decoding so the model's byte tokens are reassembled into
# UTF-8 jamo sequences (critical for jaso restoration). backend_tokenizer is
# the public accessor for the underlying tokenizers.Tokenizer.
tokenizer.backend_tokenizer.decoder = decoders.ByteLevel()
config = IsaiConfig.from_pretrained(model_dir)
model = IsaiForCausalLM(config).to(device)
# Prioritize safetensors
weights_path = os.path.join(model_dir, "model.safetensors")
if os.path.exists(weights_path):
from safetensors.torch import load_file
model.load_state_dict(load_file(weights_path))
else:
model.load_state_dict(torch.load(os.path.join(model_dir, "pytorch_model.bin"), map_location=device))
model.eval()
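# Optional sanity check (a sketch, not part of the original script): one
# forward pass on the BOS token should yield logits of shape [1, 1, vocab_size].
# assert model(torch.tensor([[config.bos_token_id]], device=device)).shape == (1, 1, config.vocab_size)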
# 3. Define the Prediction Logic with Jaso Processing
def predict(message, history):
# A. NFD Decomposition (Input)
decomposed_input = unicodedata.normalize('NFD', message)
input_ids = tokenizer.encode(decomposed_input, return_tensors="pt").to(device)
current_ids = input_ids
max_new_tokens = 50
    # B. Greedy, token-by-token generation (argmax; no sampling or KV cache)
for _ in range(max_new_tokens):
with torch.no_grad():
logits = model(current_ids)
next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0)
current_ids = torch.cat([current_ids, next_token], dim=-1)
if next_token.item() == tokenizer.eos_token_id:
break
# C. Decode and NFC Recomposition (Output)
# Only decode the generated part
generated_tokens = current_ids[0][input_ids.shape[1]:]
raw_response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
final_response = unicodedata.normalize('NFC', raw_response)
return final_response
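# Illustration of the jaso round-trip that predict() relies on (standard
# Unicode behavior, shown as a sketch rather than model output):
#   unicodedata.normalize('NFD', '한글')  # -> '한글' (6 jamo code points)
#   unicodedata.normalize('NFC', '한글')  # -> '한글' (recomposed syllables)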
# 4. Create and Launch Gradio Interface
demo = gr.ChatInterface(
    fn=predict,
    title="isai-v4.2 Jaso-Level Chat",
    # Description (Korean): "An ultra-small everyday conversation model that
    # communicates in jaso units (NFD). Input is decomposed automatically and
    # output is recomposed back into Hangul."
    description="자소 단위(NFD)로 소통하는 초소형 일상 대화 모델입니다. 입력은 자동으로 분해되고 출력은 다시 한글로 조합됩니다.",
    # Examples (Korean): "Hi? Nice to meet you.", "How's the weather today?",
    # "What's your name?"
    examples=["안녕? 반가워.", "오늘 날씨가 어때?", "너의 이름은 뭐야?"]
)
if __name__ == "__main__":
    demo.launch(share=True)