File size: 2,772 Bytes
f62dc29
e9e19db
f62dc29
e5ba9b9
f62dc29
 
9a74bcc
e6eeb28
f62dc29
 
e6eeb28
 
 
 
 
 
f62dc29
db8631a
f62dc29
 
 
 
 
 
 
c430d50
10eadd6
f62dc29
 
 
 
 
 
 
6383c22
 
1efbc24
 
6383c22
 
 
 
 
 
9a74bcc
6383c22
 
 
 
 
 
 
e6eeb28
a5f28af
6383c22
 
 
 
 
 
9a74bcc
 
d420af1
 
e11f585
 
 
 
 
d598116
 
 
d420af1
e11f585
e9e19db
 
d420af1
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import torch
import gradio as gr
from nanochat.gpt import GPT, GPTConfig
from nanochat.tokenizer import RustBPETokenizer

# --- System Initialization ---
TOKENIZER_DIR = "."
tokenizer = RustBPETokenizer.from_directory(TOKENIZER_DIR)

# Map Special Tokens
# Look up the id of each chat-template special token once and attach it to the
# tokenizer under a descriptive attribute name (bos, user start/end, assistant
# start/end), so prompt-building code can refer to ids rather than literals.
for _attr_name, _special_text in (
    ("bos_token_id", "<|bos|>"),
    ("user_start_id", "<|user_start|>"),
    ("user_end_id", "<|user_end|>"),
    ("assistant_start_id", "<|assistant_start|>"),
    ("assistant_end_id", "<|assistant_end|>"),
):
    setattr(tokenizer, _attr_name, tokenizer.enc.encode_single_token(_special_text))
# Keep the loop variables from lingering as module globals.
del _attr_name, _special_text

# Model Setup
# Architecture hyperparameters; these must describe the same network the
# checkpoint loaded below was saved from.
config = GPTConfig(
    vocab_size=32768,
    n_layer=12,
    n_head=6,
    n_embd=768,
    sequence_len=2048,
)

model = GPT(config)
# Checkpoint to serve; overridable via the MODEL_PATH environment variable so a
# different checkpoint can be deployed without editing this file.
MODEL_PATH = os.environ.get("MODEL_PATH", "model_000971.pt")
print("Loading model weights...")
state_dict = torch.load(MODEL_PATH, map_location="cpu")
# Drop the "_orig_mod." fragment from parameter names (the prefix that
# torch.compile()'s wrapper adds) so keys line up with the plain GPT module.
state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
# strict=False tolerates key mismatches, but any parameter left unmatched keeps
# its random initialization and silently degrades output; report mismatches.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(
        f"Warning: checkpoint is missing {len(load_result.missing_keys)} key(s): "
        f"{load_result.missing_keys}"
    )
if load_result.unexpected_keys:
    print(
        f"Warning: checkpoint has {len(load_result.unexpected_keys)} unexpected key(s): "
        f"{load_result.unexpected_keys}"
    )
model.eval()

def predict(message, history):
    """Stream a single-turn assistant reply to ``message`` for ``gr.ChatInterface``.

    The prompt is built as
    ``<|bos|><|user_start|>{message}<|user_end|><|assistant_start|>`` and the
    model is asked for up to 512 new tokens at temperature 0.75 with top-k 40.

    Args:
        message: The latest user message; coerced with ``str()`` and stripped.
        history: Prior turns supplied by Gradio. Currently ignored, so each
            reply is generated without earlier conversation context.

    Yields:
        The assistant text decoded so far (whitespace-stripped), growing with
        each generated token, until a token containing a stop tag appears or
        generation ends. If anything raises, a single "⚠️ System Error: ..."
        string is yielded instead and the traceback is printed to the server log.
    """
    import traceback  # only needed on the error path

    stop_tags = ("<|assistant_end|>", "<|end|>", "<|user_start|>")
    try:
        user_content = str(message).strip()
        tokens = [tokenizer.bos_token_id, tokenizer.user_start_id]
        tokens.extend(tokenizer.encode(user_content))
        tokens.append(tokenizer.user_end_id)
        tokens.append(tokenizer.assistant_start_id)

        generated_ids = []
        with torch.no_grad():
            output = model.generate(
                tokens,
                max_tokens=512,
                temperature=0.75,
                top_k=40,
            )
            for token in output:
                token_id = token if isinstance(token, int) else token.item()
                # Stop tags are single special tokens, so decoding just this
                # token is enough to detect them.
                piece = tokenizer.decode([token_id])
                if any(tag in piece for tag in stop_tags):
                    break
                generated_ids.append(token_id)
                # Decode the whole reply rather than concatenating per-token
                # strings: byte-level BPE can split one UTF-8 character across
                # several tokens, and decoding those fragments separately turns
                # each into a replacement character.
                yield tokenizer.decode(generated_ids).strip()

    except Exception as e:
        # Keep the full traceback in the server log; the user sees a short message.
        traceback.print_exc()
        yield f"⚠️ System Error: {str(e)}"

# --- UI Customization for Gradio 6.0 ---
with gr.Blocks() as demo:
    # The page is a single chat panel. `predict` is a generator, so the
    # assistant reply streams in as it yields successive partial texts.
    gr.ChatInterface(
        predict,
        examples=[
            "Explain neural network?",
            "Write a python function to calculate the area of a circle.",
            "Why is the sky blue?",
        ],
        title="⚡ SimpleAI-259M",
        description="**Fast. Focused. Simple.** A lightweight general intelligence model optimized for reasoning and logic.",
    )

if __name__ == "__main__":
    # Gradio 6.0 warns when the theme is given to gr.Blocks(), so it is passed
    # to launch() instead.
    launch_theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="slate")
    demo.launch(
        server_name="0.0.0.0",  # listen on every network interface, not just localhost
        server_port=7860,
        theme=launch_theme,
    )