App / app.py
wiped's picture
Update app.py
e543593 verified
Raw
History Blame Contribute Delete
4.51 kB
import torch
import torch.nn as nn
import torch.optim as optim
import streamlit as st
import sqlite3
import re
import urllib.request
import urllib.parse
import os
import json
import time
from datetime import datetime, timedelta
# ===== config (50 Layers / 512 Dim) =====
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Checkpoint file written by BataAI.save_brain and read by BataAI.load_brain.
MODEL_PATH = "bata_ultimate_brain.pth"
# Width of the token embeddings and of every transformer block.
EMBED_DIM = 512
# Number of GPTBlock layers stacked inside BataGPT.
N_LAYERS = 50
# Number of attention heads in each GPTBlock.
N_HEADS = 8
class GPTBlock(nn.Module):
    """One post-norm transformer block with causal self-attention.

    Args:
        dim: Feature width of the input and output.
        head: Number of attention heads (``dim`` must be divisible by it).
    """

    def __init__(self, dim, head):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, head, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim),
        )
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        """Apply causal self-attention followed by the feed-forward layer.

        Args:
            x: Tensor laid out batch-first (the attention layer is built
                with ``batch_first=True``); dimension 1 is the sequence.

        Returns:
            Tensor with the same shape as ``x``.
        """
        seq_len = x.size(1)
        # Build the mask directly on the input's device instead of on the CPU
        # followed by a copy to a global device: this keeps the block usable
        # wherever the module actually lives and avoids a transfer per call.
        # True above the diagonal means "position i may not attend to j > i".
        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device), diagonal=1
        ).bool()
        attn_out, _ = self.attn(x, x, x, attn_mask=mask)
        x = self.norm1(x + attn_out)
        return self.norm2(x + self.ff(x))
class BataGPT(nn.Module):
    """Decoder-only character model that predicts the next token.

    Args:
        vocab_size: Number of distinct token ids the model can embed and
            score.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, EMBED_DIM)
        # Learned positional table. Its second dimension (1024) is the maximum
        # context length; the shape is kept as-is so existing checkpoints load.
        self.pos_emb = nn.Parameter(torch.zeros(1, 1024, EMBED_DIM))
        self.blocks = nn.Sequential(
            *[GPTBlock(EMBED_DIM, N_HEADS) for _ in range(N_LAYERS)]
        )
        self.to_logits = nn.Linear(EMBED_DIM, vocab_size)

    def forward(self, x):
        """Return next-token logits for the last position of each sequence.

        Args:
            x: Integer tensor of token ids, batch-first; dimension 1 is the
                sequence.

        Returns:
            Tensor of shape ``(batch, vocab_size)`` holding the logits
            computed at the final sequence position.
        """
        max_len = self.pos_emb.size(1)
        if x.size(1) > max_len:
            # Inputs longer than the positional table used to fail with a
            # shape mismatch. Only the last position is scored, so keep the
            # most recent ``max_len`` tokens as context.
            x = x[:, -max_len:]
        hidden = self.embed(x) + self.pos_emb[:, :x.size(1), :]
        return self.to_logits(self.blocks(hidden)[:, -1, :])
class BataAI:
    """Owns the character vocabulary, the model, and its training loop.

    The vocabulary is stored in ``st.session_state``: ``w2i`` maps a
    character to its id and ``i2w`` is the inverse list. Ids 0 and 1 are
    reserved for ``<PAD>`` and ``<UNK>``.
    """

    # Extracts the keyword between "u<N> :" and " -" in a training record.
    _KEYWORD_RE = re.compile(r'u\d+ :(.*?) -')
    _LEARNING_RATE = 1e-5
    # Upper bound (seconds) on the best-effort web lookup so that a stalled
    # connection cannot hang the app.
    _HTTP_TIMEOUT_S = 10

    def __init__(self):
        self.init_vocab()
        self.model = BataGPT(len(st.session_state.i2w)).to(DEVICE)
        self._reset_optimizer()
        self.load_brain()

    def _reset_optimizer(self):
        """(Re)create the optimizer over the parameters of the current model.

        Must be called whenever ``self.model`` or any of its parameters is
        replaced; otherwise the optimizer keeps updating stale tensors.
        """
        self.optimizer = optim.AdamW(
            self.model.parameters(), lr=self._LEARNING_RATE
        )

    def init_vocab(self):
        """Create the session vocabulary with its seed characters if absent."""
        if 'w2i' not in st.session_state:
            st.session_state.w2i = {"<PAD>": 0, "<UNK>": 1}
            st.session_state.i2w = ["<PAD>", "<UNK>"]
            for c in "กขคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรลวศษสหฬอฮะาิีึืุูเแโใไำะัำาำ -_=:;\\/":
                self.add_char(c)

    def add_char(self, c):
        """Register ``c`` in the vocabulary if it is not already known."""
        if c not in st.session_state.w2i:
            st.session_state.w2i[c] = len(st.session_state.i2w)
            st.session_state.i2w.append(c)

    def load_brain(self):
        """Restore vocabulary and weights from ``MODEL_PATH`` when it exists."""
        if not os.path.exists(MODEL_PATH):
            return
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from a trusted source.
        ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
        st.session_state.w2i, st.session_state.i2w = ckpt['w2i'], ckpt['i2w']
        self.model = BataGPT(len(st.session_state.i2w)).to(DEVICE)
        self.model.load_state_dict(ckpt['state'])
        # The model object was replaced, so the optimizer must be rebuilt;
        # previously it kept pointing at the discarded model and the loaded
        # weights were never updated by training.
        self._reset_optimizer()

    def save_brain(self):
        """Write the model weights and the vocabulary to ``MODEL_PATH``."""
        torch.save(
            {
                'state': self.model.state_dict(),
                'w2i': st.session_state.w2i,
                'i2w': st.session_state.i2w,
            },
            MODEL_PATH,
        )

    def _grow_model_to_vocab(self):
        """Enlarge the embedding and output layers to cover the vocabulary.

        Rows for already-known ids are copied over, so learned weights are
        kept; rows for new ids start from the layers' default initialisation.
        The optimizer is rebuilt afterwards (its moment estimates restart).
        """
        vocab_size = len(st.session_state.i2w)
        old_embed = self.model.embed
        old_head = self.model.to_logits
        old_size = old_embed.num_embeddings
        if vocab_size <= old_size:
            return
        device = old_embed.weight.device
        new_embed = nn.Embedding(vocab_size, old_embed.embedding_dim).to(device)
        new_head = nn.Linear(old_head.in_features, vocab_size).to(device)
        with torch.no_grad():
            new_embed.weight[:old_size] = old_embed.weight
            new_head.weight[:old_size] = old_head.weight
            new_head.bias[:old_size] = old_head.bias
        self.model.embed = new_embed
        self.model.to_logits = new_head
        self._reset_optimizer()

    def _fetch_abstract(self, keyword):
        """Look up ``keyword`` on DuckDuckGo and return its abstract text.

        Returns:
            The ``AbstractText`` string of the JSON answer (possibly empty),
            or ``None`` when the lookup fails for any network, HTTP or
            decoding reason. The lookup is best-effort by design.
        """
        from http.client import HTTPException

        url = f"https://api.duckduckgo.com/?q={urllib.parse.quote(keyword)}&format=json"
        try:
            with urllib.request.urlopen(url, timeout=self._HTTP_TIMEOUT_S) as res:
                data = json.loads(res.read().decode())
        except (OSError, ValueError, HTTPException):
            # URLError/HTTPError/timeouts are OSError; bad JSON or bad UTF-8
            # is ValueError. Training simply proceeds without the extra text.
            return None
        if not isinstance(data, dict):
            return None
        text = data.get('AbstractText', '')
        return text if isinstance(text, str) else None

    def parse_and_train_txt(self, content):
        """Parse ';'-separated records and run one training step per record.

        Records containing the marker "Leaning" (spelled as in the data
        format) are enriched with a web abstract for the keyword found
        between ``u<N> :`` and `` -``. Each record is then used as one
        next-character example: all characters but the last are the context,
        the last character is the target.

        Args:
            content: Raw text holding records separated by ';'.
        """
        # Split into records (u1, u2, ...) and drop blank ones.
        records = [rec for rec in content.split(';') if rec.strip()]
        if not records:
            return

        # Register every character first, then resize the model once. Before
        # this, new characters got ids beyond the embedding table and the
        # forward pass failed with an index error.
        for rec in records:
            for char in rec:
                self.add_char(char)
        self._grow_model_to_vocab()

        w2i = st.session_state.w2i
        window = self.model.pos_emb.size(1)
        self.model.train()
        for rec in records:
            # "Leaning" mode: append text fetched from the web so the model
            # sees additional context for the record's keyword.
            if "Leaning" in rec:
                m = self._KEYWORD_RE.search(rec)
                if m:
                    extra = self._fetch_abstract(m.group(1).strip())
                    if extra is not None:
                        rec += " " + extra

            # Characters that only appear in fetched text map to <UNK> (1).
            ids = [w2i.get(c, 1) for c in rec]
            if len(ids) < 2:
                continue
            # Keep at most `window` context tokens plus the target so the
            # input never exceeds the positional table.
            ids = ids[-(window + 1):]
            context = torch.tensor([ids[:-1]], dtype=torch.long, device=DEVICE)
            target = torch.tensor([ids[-1]], dtype=torch.long, device=DEVICE)

            self.optimizer.zero_grad()
            logits = self.model(context)
            # Cross-entropy against the true next character. The previous
            # objective summed the raw logits, which has no target and does
            # not teach the model anything.
            loss = nn.functional.cross_entropy(logits, target)
            loss.backward()
            self.optimizer.step()

        # Persist once per call; writing the full checkpoint after every
        # record would be needlessly expensive.
        self.save_brain()
# (The Streamlit UI section below can be reused unchanged from v11.)