# Minimal character-level GPT-style language model with a Gradio chat UI.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gradio as gr

# Load the training corpus and build a character-level vocabulary.
with open("dataset.txt", "r", encoding="utf-8") as f:
    text = f.read().lower()

chars = sorted(set(text))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

# Map text to integer ids and back; unknown characters fall back to id 0.
def encode(s): return [stoi.get(c, 0) for c in s]
def decode(l): return "".join(itos[i] for i in l)

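# Sanity note: encode/decode round-trips for characters present in the corpus,
# e.g. decode(encode("abc")) == "abc" whenever "a", "b", and "c" occur in text.
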
# One Transformer decoder block: masked self-attention plus a feed-forward
# network, each wrapped in a residual connection and LayerNorm.
class GPTBlock(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward, dropout):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.GELU(),
            nn.Linear(dim_feedforward, d_model),
            nn.Dropout(dropout),
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        # nn.MultiheadAttention defaults to (seq_len, batch, d_model) inputs;
        # the additive mask blocks attention to future positions.
        attn_out, _ = self.attn(x, x, x, attn_mask=mask)
        x = self.ln1(x + attn_out)
        ff_out = self.ff(x)
        x = self.ln2(x + ff_out)
        return x

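# Note: this is the original post-LN layout (LayerNorm after each residual
# add); many modern GPT implementations normalize before each sublayer instead.
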
# Decoder-only model: token embeddings plus learned positional embeddings,
# a stack of causal blocks, and a linear head over the vocabulary.
class GPTModel(nn.Module):
    def __init__(self, vocab_size, d_model=128, nhead=8, num_layers=4,
                 dim_feedforward=512, max_len=5000, dropout=0.1):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Parameter(torch.zeros(1, max_len, d_model))
        self.blocks = nn.ModuleList(
            [GPTBlock(d_model, nhead, dim_feedforward, dropout) for _ in range(num_layers)]
        )
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        seq_len = x.size(1)
        token_embeddings = self.token_emb(x)   # (batch, seq, d_model)
        pos_embeddings = self.pos_emb[:, :seq_len, :]
        x = token_embeddings + pos_embeddings
        x = x.transpose(0, 1)                  # -> (seq, batch, d_model)

        # Causal mask: -inf above the diagonal hides future tokens.
        mask = torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=x.device), diagonal=1)

        for block in self.blocks:
            x = block(x, mask)

        x = x.transpose(0, 1)                  # back to (batch, seq, d_model)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits

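# Optional sanity check (builds a throwaway model on CPU): logits should come
# back with shape (batch, seq_len, vocab_size).
_probe = torch.randint(0, vocab_size, (2, 10))
assert GPTModel(vocab_size)(_probe).shape == (2, 10, vocab_size)
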
# Training setup.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPTModel(vocab_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
criterion = nn.CrossEntropyLoss()

seq_len = 25
batch_size = 1  # each step below trains on a single randomly chosen window
epochs = 300

data_tensor = torch.tensor(encode(text), dtype=torch.long)

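# Guard: the random-window sampler below assumes the corpus is longer than one
# training window.
assert len(data_tensor) > seq_len + 1, "dataset.txt is too short for seq_len=25"
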
# Training loop: sample one random window per step; the target is the input
# shifted one character to the right (next-character prediction).
for epoch in range(epochs):
    model.train()
    idx = np.random.randint(0, len(data_tensor) - seq_len - 1)
    chunk = data_tensor[idx:idx + seq_len + 1].unsqueeze(0).to(device)
    input_seq = chunk[:, :-1]
    target_seq = chunk[:, 1:]

    optimizer.zero_grad()
    logits = model(input_seq)
    loss = criterion(logits.view(-1, vocab_size), target_seq.view(-1))
    loss.backward()
    optimizer.step()

    if epoch % 50 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

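# A batched sampler (a sketch, not used by the loop above): stacking several
# random windows per step would give each update more signal than the
# single-window, batch-size-1 scheme used here.
def get_batch(bs=32):
    ix = np.random.randint(0, len(data_tensor) - seq_len - 1, size=bs)
    chunks = torch.stack([data_tensor[i:i + seq_len + 1] for i in ix]).to(device)
    return chunks[:, :-1], chunks[:, 1:]
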
# Autoregressive sampling: repeatedly feed the growing sequence back into the
# model and sample the next character from the softmax distribution.
def generate_text(model, seed, max_len=100):
    model.eval()
    input_ids = torch.tensor(encode(seed), dtype=torch.long).unsqueeze(0).to(device)
    generated = seed

    with torch.no_grad():
        for _ in range(max_len):
            logits = model(input_ids)
            probs = F.softmax(logits[0, -1], dim=-1)
            # torch.multinomial replaces np.random.choice here: the latter can
            # reject a float32 softmax that does not sum to exactly 1.
            next_id = torch.multinomial(probs, num_samples=1).item()
            generated += itos[next_id]
            next_token = torch.tensor([[next_id]], device=device)
            input_ids = torch.cat([input_ids, next_token], dim=1)

    return generated

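# Example usage (the seed should stick to characters that appear in the corpus):
# print(generate_text(model, "hello ", max_len=200))
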
# Gradio chat wrapper: continue the prompt and return only the newly
# generated text (the echoed seed is stripped off).
def chat_with_ai(inp):
    inp = inp.lower()  # the vocabulary was built from lowercased text
    if not inp:
        return ""
    return generate_text(model, inp, max_len=100)[len(inp):]

iface = gr.Interface(
    fn=chat_with_ai,
    inputs=gr.Textbox(lines=1, placeholder="Type your message..."),
    outputs="text",
    title="GPT-Style Transformer Chat AI",
    description="Chat AI using a simple GPT-style Transformer model",
)

iface.launch()