# agillm — self-learning toy LLM demo (Hugging Face Space).
# NOTE(review): the Space was failing to build; fixes below address the
# broken deque import and the nonexistent gr.get_state() call.
# Standard-library imports.
import random
import threading
import time
from collections import deque  # FIX: deque lives in collections, not torch.utils.data

# Third-party imports.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import gradio as gr
| # =========================== | |
| # 1. Tiny Transformer Model | |
| # =========================== | |
class TinyTransformer(nn.Module):
    """Minimal Transformer-encoder language model over a toy vocabulary.

    Args:
        vocab_size: number of token ids in the embedding/output space.
        d_model: embedding and model width.
        n_heads: attention heads per encoder layer.
        n_layers: number of encoder layers.
        max_len: maximum sequence length the positional table supports.
    """

    def __init__(self, vocab_size=5000, d_model=128, n_heads=4, n_layers=2, max_len=128):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        # FIX: batch_first=True so (B, T, d) inputs are read as
        # (batch, seq, dim). The default (False) treats dim 0 as the
        # sequence axis, silently attending across the batch.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=256, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, tokens):
        """tokens: (B, T) long tensor of ids. Returns logits of shape (B, T, vocab)."""
        B, T = tokens.shape
        # Build positions directly on the input's device (no CPU round-trip).
        positions = torch.arange(T, device=tokens.device).unsqueeze(0).expand(B, T)
        x = self.token_emb(tokens) + self.pos_emb(positions)
        x = self.transformer(x)
        return self.fc(x)  # (B, T, vocab)
| # =========================== | |
| # 2. Tokenizer (super simple) | |
| # =========================== | |
class BasicTokenizer:
    """Whitespace tokenizer that hashes each word into a fixed-size id space.

    FIX: uses a stable digest instead of the builtin hash(), which is
    salted per process (PYTHONHASHSEED) — with hash(), the same word mapped
    to a different id on every restart, invalidating anything the model
    had learned about it.
    """

    def __init__(self, vocab_size=5000):
        self.vocab_size = vocab_size

    def encode(self, text):
        """Map each whitespace-separated token to a deterministic id in [0, vocab_size)."""
        import hashlib  # local import keeps this block self-contained
        return [
            int(hashlib.md5(tok.encode("utf-8")).hexdigest(), 16) % self.vocab_size
            for tok in text.split()
        ]

    def decode(self, ids):
        """Placeholder inverse: hashing is lossy, so ids render as "<id>"."""
        return " ".join(f"<{i}>" for i in ids)
# Shared global state used by the training thread and the Gradio callbacks.
tokenizer = BasicTokenizer()
# ======================================
# 3. Experience Memory for RL + Imitation
# ======================================
# Ring buffer of (prompt, answer, reward) tuples fed by user feedback;
# entries older than the newest 2000 are discarded.
experience = deque(maxlen=2000)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyTransformer().to(device)
optimizer = optim.Adam(model.parameters(), lr=3e-4)
# Weight of the reward-weighted term relative to the imitation (CE) loss.
reward_scale = 1.0
| # ====================================== | |
| # 4. Training Loop (runs in background) | |
| # ====================================== | |
def training_thread():
    """Background loop: sample feedback tuples and take mixed CE + reward steps.

    Runs forever (intended as a daemon thread). Reads the module-level
    `experience` buffer and updates the module-level `model` in place.
    """
    while True:
        if len(experience) < 5:
            time.sleep(0.2)
            continue
        # FIX: snapshot the deque before sampling — another thread appends
        # to it concurrently, and sampling needs a stable sequence.
        batch = random.sample(list(experience), 4)
        losses = []
        for prompt, target_answer, reward in batch:
            encoded_in = tokenizer.encode(prompt)
            encoded_out = tokenizer.encode(target_answer)
            # FIX: skip degenerate samples — empty prompt or answer would
            # crash the forward pass / cross_entropy on zero-length tensors.
            if not encoded_in or not encoded_out:
                continue
            x = torch.tensor([encoded_in], dtype=torch.long, device=device)
            y = torch.tensor([encoded_out], dtype=torch.long, device=device)
            logits = model(x)
            # FIX: align BOTH sides to the shared length; the original
            # sliced only the logits and crashed whenever the answer was
            # longer than the prompt.
            T = min(logits.shape[1], y.shape[1])
            logits = logits[:, :T, :]
            y = y[:, :T]
            ce_loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                y.reshape(-1),
            )
            # Crude policy-gradient surrogate: push mean logits up/down by reward.
            rl_loss = -reward * logits.mean()
            losses.append(ce_loss + reward_scale * rl_loss)
        # FIX: guard empty batch (all samples skipped) — the original would
        # divide by zero / backward on a non-tensor.
        if losses:
            total_loss = sum(losses) / len(losses)
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
        time.sleep(0.1)
# Launch the learner as a daemon so it dies with the main (Gradio) process.
thread = threading.Thread(target=training_thread, daemon=True)
thread.start()
| # ====================================== | |
| # 5. Inference | |
| # ====================================== | |
def generate(model, text, max_new_tokens=20):
    """Greedy-decode up to `max_new_tokens` continuations of `text`.

    Uses the module-level `tokenizer` and `device`. Returns the decoded
    string for the full sequence, prompt tokens included.
    """
    model.eval()
    ids = tokenizer.encode(text)
    if not ids:
        # FIX: an empty prompt produced a zero-length sequence and crashed
        # the forward pass.
        return ""
    ids = torch.tensor([ids], dtype=torch.long, device=device)
    # Positional table size — positions beyond it would raise an index error.
    max_len = model.pos_emb.num_embeddings
    with torch.no_grad():  # FIX: inference must not accumulate autograd graph
        for _ in range(max_new_tokens):
            # FIX: clamp the context window to the positional-embedding limit.
            context = ids[:, -max_len:]
            logits = model(context)
            next_token = logits[:, -1, :].argmax(dim=-1)
            ids = torch.cat([ids, next_token.unsqueeze(-1)], dim=1)
    return tokenizer.decode(ids[0].tolist())
| # ====================================== | |
| # 6. Gradio App | |
| # ====================================== | |
def chat(prompt, history):
    """Gradio handler: generate a reply, append the turn, clear the input box."""
    reply = generate(model, prompt)
    history.append((prompt, reply))
    # FIX: the second output is wired to the input Textbox — return "" to
    # clear it, instead of echoing the bot's reply into the user's input.
    return history, ""
def feedback(data, fb=0):
    """Store the last chat turn with a reward in the experience buffer.

    Args:
        data: chatbot history — a list of (user, bot) tuples.
        fb: reward signal (+1 / -1), wired in from the reward State
            component. FIX: taken as an input — the original called
            gr.get_state(), which is not a Gradio API and raised
            AttributeError on every click. Defaulted to 0 for
            backward compatibility with single-input wiring.
    """
    if not data:
        # FIX: no turns yet — indexing data[-1] would raise IndexError.
        return "No conversation to rate yet."
    last_user, last_bot = data[-1]  # most recent (user, bot) turn
    experience.append((last_user, last_bot, fb))
    return "Thanks for the feedback!"
# Build the UI: a chat window plus thumbs-up/down feedback that feeds the
# experience buffer consumed by the background trainer.
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 agillm — Self Learning LLM")
    chatbox = gr.Chatbot()
    user_input = gr.Textbox()
    send = gr.Button("Send")
    fb_up = gr.Button("👍")
    fb_down = gr.Button("👎")
    # Holds the most recent reward (+1 / -1) between the click and the
    # follow-up feedback call.
    state_reward = gr.State(0)
    send.click(chat, [user_input, chatbox], [chatbox, user_input])
    # FIX: pass the reward State into feedback explicitly; the original
    # relied on gr.get_state(), which does not exist in the Gradio API.
    fb_up.click(lambda: 1, None, state_reward).then(
        feedback, [chatbox, state_reward], None
    )
    fb_down.click(lambda: -1, None, state_reward).then(
        feedback, [chatbox, state_reward], None
    )
demo.launch()