| | import torch |
| | import torch.nn as nn |
| | from nltk.tokenize import word_tokenize |
| | import nltk |
| | from flask import Flask, request, jsonify |
| |
|
# Fetch the Punkt tokenizer data required by word_tokenize().
# Runs at import time (network side effect on first run; no-op afterwards).
# NOTE(review): NLTK >= 3.8.2 also requires the "punkt_tab" resource for
# word_tokenize — confirm against the pinned nltk version.
nltk.download('punkt')
| |
|
| | |
| | |
| | |
| |
|
class TransformerModel(nn.Module):
    """Word-level Transformer language model.

    Token embeddings plus learned positional embeddings feed a stack of
    ``nn.TransformerEncoderLayer`` blocks, followed by a final LayerNorm
    and a linear head projecting back to vocabulary logits.

    NOTE(review): no causal attention mask is applied, so every position
    attends to the whole sequence, including future tokens. That is what
    the shipped checkpoint was presumably trained with, so it is left
    unchanged here — but it is unusual for autoregressive generation;
    confirm against the training code.
    """

    def __init__(self, vocab_size, n_embd=512, n_head=16, n_layer=10, block_size=256):
        super().__init__()
        self.block_size = block_size  # maximum supported sequence length

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)

        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=n_embd,
                nhead=n_head,
                dim_feedforward=n_embd * 4,  # conventional 4x FFN width
                dropout=0.1,
                activation="gelu",
                batch_first=True,
            )
            for _ in range(n_layer)
        ])

        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx):
        """Return next-token logits for a batch of token-id sequences.

        Args:
            idx: LongTensor of token ids, shape (B, T) with T <= block_size.

        Returns:
            FloatTensor of logits, shape (B, T, vocab_size).

        Raises:
            ValueError: if T exceeds block_size. (Previously this surfaced
                as a cryptic index error inside the position embedding.)
        """
        B, T = idx.shape
        if T > self.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.block_size}"
            )

        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb  # positions broadcast across the batch dim

        for layer in self.layers:
            x = layer(x)

        x = self.ln_f(x)
        logits = self.head(x)

        return logits
| |
|
| | |
| | |
| | |
| |
|
class WordTokenizer:
    """Word-level tokenizer backed by a JSON vocabulary file.

    The file must contain a ``"word_to_id"`` mapping; an ``"<unk>"``
    entry is required and is used for out-of-vocabulary words.
    """

    def __init__(self, vocab_path):
        import json
        with open(vocab_path, "r", encoding="utf-8") as f:
            obj = json.load(f)
        self.word_to_id = obj["word_to_id"]
        # Inverse mapping for decode(); JSON object keys are strings, so
        # the stored ids are cast back to int here.
        self.id_to_word = {int(v): k for k, v in self.word_to_id.items()}

    def encode(self, text):
        """Lower-case and word-tokenize *text*, returning a list of ids."""
        words = word_tokenize(text.lower())
        # Hoist the loop-invariant "<unk>" lookup out of the comprehension.
        # Like the original (which evaluated the .get default eagerly for
        # every word), this raises KeyError if "<unk>" is missing.
        unk_id = self.word_to_id["<unk>"]
        return [self.word_to_id.get(w, unk_id) for w in words]

    def decode(self, ids):
        """Map a sequence of ids back to a space-joined string."""
        return " ".join([self.id_to_word.get(i, "<unk>") for i in ids])
| |
|
| | |
| | |
| | |
| |
|
# ---------------------------------------------------------------------------
# Module-level setup: tokenizer, model weights, Flask app. These run on
# import, so importing this module requires vocab.json and
# pytorch_model.bin to exist in the working directory.
# ---------------------------------------------------------------------------
tokenizer = WordTokenizer("vocab.json")
VOCAB_SIZE = len(tokenizer.word_to_id)

model = TransformerModel(VOCAB_SIZE)
# weights_only=True: the checkpoint is consumed as a plain state_dict, and
# the default full-pickle load would execute arbitrary code embedded in an
# untrusted file. Requires torch >= 1.13 — confirm the deployed version.
model.load_state_dict(
    torch.load("pytorch_model.bin", map_location="cpu", weights_only=True)
)
model.eval()  # inference mode: disables dropout

app = Flask(__name__)
| |
|
| | |
| | |
| | |
| |
|
def generate(text, max_new_tokens=50):
    """Autoregressively sample a continuation of *text*.

    Args:
        text: Prompt string; it is tokenized and included in the output.
        max_new_tokens: Number of tokens to sample after the prompt.

    Returns:
        The decoded prompt plus sampled continuation, as a string.
    """
    ids = tokenizer.encode(text)
    x = torch.tensor([ids], dtype=torch.long)

    # no_grad: sampling needs no autograd graph; without it every step
    # accumulated graph state and wasted memory.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Crop the context to the model's window. The original fed the
            # full, ever-growing history and crashed with a
            # position-embedding index error once it exceeded block_size.
            x_cond = x[:, -model.block_size:]
            logits = model(x_cond)
            last = logits[0, -1]
            probs = torch.softmax(last, dim=0)
            next_id = torch.multinomial(probs, num_samples=1)
            x = torch.cat([x, next_id.unsqueeze(0)], dim=1)

    out = x[0].tolist()
    return tokenizer.decode(out)
| |
|
| | |
| | |
| | |
| |
|
@app.route("/chat", methods=["POST"])
def chat_api():
    """POST /chat with JSON {"text": ...}; responds {"response": ...}.

    Returns 400 on a missing/invalid JSON body or a missing "text" field
    (previously request.get_json() returned None and data["text"] raised
    an unhandled 500).
    """
    data = request.get_json(silent=True)
    if not data or "text" not in data:
        return jsonify({"error": "JSON body with a 'text' field is required"}), 400
    user_text = data["text"]
    response = generate(user_text, max_new_tokens=40)
    return jsonify({"response": response})
| |
|
if __name__ == "__main__":
    # Flask development server (debug off; default host 127.0.0.1:5000).
    # Use a production WSGI server (e.g. gunicorn) for real deployments.
    app.run()
| |
|