import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
import random
import threading
import gradio as gr
import time
# ===========================
# 1. Tiny Transformer Model
# ===========================
class TinyTransformer(nn.Module):
    def __init__(self, vocab_size=5000, d_model=128, n_heads=4, n_layers=2, max_len=128):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=256,
            batch_first=True,  # inputs below are (B, T, d_model)
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, tokens):
        B, T = tokens.shape
        positions = torch.arange(T, device=tokens.device).unsqueeze(0).expand(B, T)
        x = self.token_emb(tokens) + self.pos_emb(positions)
        x = self.transformer(x)
        logits = self.fc(x)  # (B, T, vocab)
        return logits
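# Shape check (hypothetical usage, not run at import time):
#   m = TinyTransformer()
#   out = m(torch.randint(0, 5000, (2, 16)))  # -> (2, 16, 5000) logits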
# ===========================
# 2. Tokenizer (super simple)
# ===========================
class BasicTokenizer:
    def __init__(self, vocab_size=5000):
        self.vocab_size = vocab_size

    def encode(self, text):
        # Extremely dumb tokenizer: split on whitespace, then hash each word
        # into the vocab range. Note: str hash() is salted per process
        # (PYTHONHASHSEED), so ids are not stable across runs.
        return [abs(hash(t)) % self.vocab_size for t in text.split()]

    def decode(self, ids):
        return " ".join(f"<{i}>" for i in ids)  # placeholder, not invertible
tokenizer = BasicTokenizer()
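# Round-trip example (hypothetical): encoding is hash-based, so decode
# cannot recover text, only placeholder ids:
#   tokenizer.encode("hello world")  # e.g. [1234, 987], varies per process
#   tokenizer.decode([1234, 987])    # "<1234> <987>"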
# ======================================
# 3. Experience Memory for RL + Imitation
# ======================================
experience = deque(maxlen=2000)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyTransformer().to(device)
optimizer = optim.Adam(model.parameters(), lr=3e-4)
reward_scale = 1.0
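# Each experience entry is a (prompt, bot_reply, reward) tuple, where reward
# is +1 (👍) or -1 (👎) from the UI below; reward_scale weights the
# reward-driven term against the imitation (cross-entropy) term.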
# ======================================
# 4. Training Loop (runs in background)
# ======================================
def training_thread():
    while True:
        if len(experience) < 5:
            time.sleep(0.2)
            continue
        batch = random.sample(experience, 4)
        model.train()
        losses = []
        for prompt, target_answer, reward in batch:
            encoded_in = tokenizer.encode(prompt)
            encoded_out = tokenizer.encode(target_answer)
            if not encoded_in or not encoded_out:
                continue
            x = torch.tensor([encoded_in], dtype=torch.long, device=device)
            y = torch.tensor([encoded_out], dtype=torch.long, device=device)
            logits = model(x)  # (1, len(encoded_in), vocab)
            # Align the two sequences: compare only the overlapping prefix.
            T = min(logits.shape[1], y.shape[1])
            logits, y = logits[:, :T, :], y[:, :T]
            # Imitation term: cross-entropy against the stored reply.
            ce_loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                y.reshape(-1)
            )
            # Crude reward-weighted term: pushes mean logits up or down
            # depending on the user's +1/-1 feedback.
            rl_loss = -reward * logits.mean()
            losses.append(ce_loss + reward_scale * rl_loss)
        if not losses:
            time.sleep(0.1)
            continue
        total_loss = sum(losses) / len(losses)
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        time.sleep(0.1)
thread = threading.Thread(target=training_thread, daemon=True)
thread.start()
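# Note: the trainer and the Gradio handlers share `model` without any
# synchronization. A minimal guard (a sketch, not wired in above) could be:
#   update_lock = threading.Lock()
#   with update_lock:
#       optimizer.zero_grad(); total_loss.backward(); optimizer.step()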
# ======================================
# 5. Inference
# ======================================
def generate(model, text, max_new_tokens=20):
    model.eval()
    ids = tokenizer.encode(text) or [0]  # avoid an empty input tensor
    ids = torch.tensor([ids], dtype=torch.long, device=device)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            if ids.shape[1] >= 128:  # positional embedding limit (max_len)
                break
            logits = model(ids)
            next_token = logits[:, -1, :].argmax(dim=-1)  # greedy decoding
            ids = torch.cat([ids, next_token.unsqueeze(-1)], dim=1)
    return tokenizer.decode(ids[0].tolist())
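# Decoding above is greedy (argmax). A sampling alternative (a sketch; the
# temperature value 0.8 is an assumption, not from the original code):
#   probs = F.softmax(logits[:, -1, :] / 0.8, dim=-1)
#   next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)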
# ======================================
# 6. Gradio App
# ======================================
def chat(prompt, history):
    reply = generate(model, prompt)
    history.append((prompt, reply))
    return history, ""  # clear the input box after sending

def feedback(history, fb):
    if not history:
        return "Nothing to rate yet."
    last_user, last_bot = history[-1]  # rate the most recent turn
    experience.append((last_user, last_bot, fb))
    return "Thanks for the feedback!"

with gr.Blocks() as demo:
    gr.Markdown("# 🧠 agillm - Self Learning LLM")
    chatbox = gr.Chatbot()
    user_input = gr.Textbox()
    send = gr.Button("Send")
    fb_up = gr.Button("👍")
    fb_down = gr.Button("👎")
    status = gr.Markdown()
    state_reward = gr.State(0)
    send.click(chat, [user_input, chatbox], [chatbox, user_input])
    fb_up.click(lambda: 1, None, state_reward).then(
        feedback, [chatbox, state_reward], status)
    fb_down.click(lambda: -1, None, state_reward).then(
        feedback, [chatbox, state_reward], status)

demo.launch()