File size: 4,508 Bytes
9ad0895
 
 
 
 
 
 
 
68abdea
aaa7254
68abdea
bb8aca6
9ad0895
e543593
9ad0895
e543593
 
 
68abdea
9ad0895
68abdea
 
2a74cce
68abdea
 
e543593
68abdea
9ad0895
68abdea
56312be
 
e543593
fdf5843
 
68abdea
fdf5843
68abdea
 
 
 
fdf5843
68abdea
e543593
9ad0895
 
68abdea
 
 
e543593
68abdea
 
 
9ad0895
68abdea
 
e543593
 
68abdea
e543593
 
 
 
68abdea
 
 
e543593
 
68abdea
e543593
68abdea
e543593
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68abdea
1aeb9f2
e543593
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
import torch.nn as nn
import torch.optim as optim
import streamlit as st
import sqlite3
import re
import urllib.request
import urllib.parse
import os
import json
import time
from datetime import datetime, timedelta

# ===== Model / runtime configuration (50 layers, 512-dim) =====
# Transformer depth, width and attention-head count used by BataGPT.
N_LAYERS = 50
EMBED_DIM = 512
N_HEADS = 8
# Checkpoint file holding weights plus the character vocabulary.
MODEL_PATH = "bata_ultimate_brain.pth"
# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class GPTBlock(nn.Module):
    """One post-norm transformer decoder layer: causal self-attention, then a 4x-wide MLP.

    Args:
        dim: model width (embedding size); nn.MultiheadAttention requires it to be
            divisible by ``head``.
        head: number of attention heads.
    """

    def __init__(self, dim, head):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, head, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        """Apply the layer to ``x`` of shape (batch, seq, dim); returns the same shape."""
        seq_len = x.size(1)
        # Boolean mask, True strictly above the diagonal: position i may not attend
        # to any later position j > i. It is built directly on x.device so the block
        # works wherever the module actually lives, not only on a global device.
        mask = torch.ones(seq_len, seq_len, device=x.device).triu(1).bool()
        # The averaged attention map is discarded, so don't ask MHA to compute it.
        attn_out, _ = self.attn(x, x, x, attn_mask=mask, need_weights=False)
        x = self.norm1(x + attn_out)
        return self.norm2(x + self.ff(x))

class BataGPT(nn.Module):
    """Decoder-only model: token + learned position embeddings, N_LAYERS GPTBlocks, vocab head.

    ``forward`` returns logits for the LAST input position only, shape (batch, vocab_size).

    Args:
        vocab_size: number of token ids covered by the embedding and output head.
        max_len: size of the learned positional table, i.e. the longest sequence
            ``forward`` accepts. The default 1024 matches previously saved checkpoints.
    """

    def __init__(self, vocab_size, max_len=1024):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, EMBED_DIM)
        self.pos_emb = nn.Parameter(torch.zeros(1, max_len, EMBED_DIM))
        self.blocks = nn.Sequential(*[GPTBlock(EMBED_DIM, N_HEADS) for _ in range(N_LAYERS)])
        self.to_logits = nn.Linear(EMBED_DIM, vocab_size)

    def forward(self, x):
        """Score the next token after each sequence of token ids in ``x`` (batch, seq).

        Raises:
            ValueError: if the sequence is longer than the positional table.
        """
        seq_len = x.size(1)
        max_len = self.pos_emb.size(1)
        if seq_len > max_len:
            # Slicing pos_emb would silently yield only max_len rows and then fail
            # with an opaque broadcast error; report the actual problem instead.
            raise ValueError(f"sequence length {seq_len} exceeds max_len {max_len}")
        h = self.embed(x) + self.pos_emb[:, :seq_len, :]
        return self.to_logits(self.blocks(h)[:, -1, :])

class BataAI:
    """Trainer around BataGPT with a character vocabulary kept in Streamlit session state.

    ``st.session_state.w2i`` maps character -> token id and ``st.session_state.i2w``
    is the inverse list. Both are saved to and restored from ``MODEL_PATH``
    together with the model weights.
    """

    # Learning rate for AdamW. It is kept in one place because the optimizer is
    # rebuilt whenever the model's parameter objects are replaced (checkpoint load,
    # vocabulary growth).
    LR = 1e-5
    # Captures the keyword in a "Leaning" record, e.g. "u1 :keyword - ...".
    _LEANING_KEYWORD_RE = re.compile(r'u\d+ :(.*?) -')
    # Seconds to wait for the DuckDuckGo API before giving up on a lookup.
    NET_TIMEOUT = 10

    def __init__(self):
        self.init_vocab()
        self.model = BataGPT(len(st.session_state.i2w)).to(DEVICE)
        self._build_optimizer()
        self.load_brain()

    def _build_optimizer(self):
        """(Re)create the optimizer over the CURRENT model's parameters."""
        self.optimizer = optim.AdamW(self.model.parameters(), lr=self.LR)

    def init_vocab(self):
        """Seed the session vocabulary with <PAD>, <UNK> and a base charset, once per session."""
        if 'w2i' not in st.session_state:
            st.session_state.w2i = {"<PAD>": 0, "<UNK>": 1}
            st.session_state.i2w = ["<PAD>", "<UNK>"]
            # Repeated characters in this literal are harmless: add_char skips known ones.
            for c in "กขคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรลวศษสหฬอฮะาิีึืุูเแโใไำะัำาำ -_=:;\\/":
                self.add_char(c)

    def add_char(self, c):
        """Append character ``c`` to the session vocabulary if it is not already known.

        The model is not resized here; ``_sync_vocab_size`` grows it before the next
        training step or save.
        """
        if c not in st.session_state.w2i:
            st.session_state.w2i[c] = len(st.session_state.i2w)
            st.session_state.i2w.append(c)

    def _sync_vocab_size(self):
        """Grow the embedding and output head so every vocabulary id has a row.

        Existing rows are copied, so known characters keep their learned weights.
        New rows use PyTorch's default initialisation. The optimizer is rebuilt
        because the grown layers are new parameter objects, so AdamW moment
        estimates are reset.
        """
        new_size = len(st.session_state.i2w)
        old_embed = self.model.embed
        old_size = old_embed.num_embeddings
        if new_size <= old_size:
            return
        old_head = self.model.to_logits
        device = old_embed.weight.device
        embed = nn.Embedding(new_size, old_embed.embedding_dim).to(device)
        head = nn.Linear(old_head.in_features, new_size).to(device)
        with torch.no_grad():
            embed.weight[:old_size] = old_embed.weight
            head.weight[:old_size] = old_head.weight
            head.bias[:old_size] = old_head.bias
        self.model.embed = embed
        self.model.to_logits = head
        self._build_optimizer()

    def load_brain(self):
        """Restore weights and vocabulary from MODEL_PATH if that file exists."""
        if not os.path.exists(MODEL_PATH):
            return
        # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
        ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
        st.session_state.w2i, st.session_state.i2w = ckpt['w2i'], ckpt['i2w']
        # Size the model from the saved weights, not from i2w. Older checkpoints may
        # hold a vocabulary that outgrew their weights.
        self.model = BataGPT(ckpt['state']['embed.weight'].size(0)).to(DEVICE)
        self.model.load_state_dict(ckpt['state'])
        # The previous optimizer still references the discarded model's parameters;
        # without a rebuild, training would update a model nobody uses.
        self._build_optimizer()
        self._sync_vocab_size()

    def save_brain(self):
        """Write weights and vocabulary to MODEL_PATH (model grown first so sizes agree)."""
        self._sync_vocab_size()
        torch.save({'state': self.model.state_dict(), 'w2i': st.session_state.w2i, 'i2w': st.session_state.i2w}, MODEL_PATH)

    def _fetch_net_text(self, keyword):
        """Return DuckDuckGo's AbstractText for ``keyword`` ('' if absent), or None on failure."""
        url = f"https://api.duckduckgo.com/?q={urllib.parse.quote(keyword)}&format=json"
        try:
            with urllib.request.urlopen(url, timeout=self.NET_TIMEOUT) as res:
                data = json.loads(res.read().decode())
        except (OSError, ValueError):
            # Network/HTTP errors and timeouts are OSError subclasses; malformed JSON
            # or undecodable bytes are ValueError. Web enrichment is best-effort.
            return None
        if not isinstance(data, dict):
            return None
        text = data.get('AbstractText', '')
        return text if isinstance(text, str) else None

    def _train_line(self, line):
        """Take one optimizer step teaching the model to predict ``line``'s last char from the rest."""
        tokens = [st.session_state.w2i.get(c, 1) for c in line]
        if len(tokens) < 2:
            return
        self._sync_vocab_size()
        # The positional table is finite; keep only the most recent context that fits.
        max_ctx = self.model.pos_emb.size(1)
        tokens = tokens[-(max_ctx + 1):]
        x = torch.tensor([tokens[:-1]], device=DEVICE)
        target = torch.tensor([tokens[-1]], device=DEVICE)
        self.model.train()
        self.optimizer.zero_grad()
        logits = self.model(x)  # (1, vocab): BataGPT scores only the last position
        loss = nn.functional.cross_entropy(logits, target)
        loss.backward()
        self.optimizer.step()

    def parse_and_train_txt(self, content):
        """Train on the ';'-separated records in ``content``, then save a checkpoint.

        Records containing "Leaning" trigger a DuckDuckGo lookup of the keyword
        captured by ``_LEANING_KEYWORD_RE``. A successful lookup's abstract is
        appended to the record as extra training context.
        """
        for line in content.split(';'):  # records are separated by ';' (u1, u2, ...)
            if not line.strip():
                continue
            if "Leaning" in line:
                m = self._LEANING_KEYWORD_RE.search(line)
                if m:
                    net_text = self._fetch_net_text(m.group(1).strip())
                    if net_text is not None:
                        line += " " + net_text
            # Register characters AFTER appending the web text so it is not all <UNK>.
            for char in line:
                self.add_char(char)
            self._train_line(line)
        self.save_brain()

# (The Streamlit UI section that follows is unchanged from v11 and can be reused as-is.)