File size: 3,828 Bytes
f7dab62
 
 
 
e85a29f
d3b822e
e85a29f
 
f7dab62
 
e85a29f
f7dab62
e85a29f
f7dab62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e85a29f
 
 
 
f7dab62
 
 
 
e85a29f
f7dab62
e85a29f
 
f7dab62
 
d3b822e
 
 
 
f7dab62
 
 
 
 
 
 
 
 
e85a29f
f7dab62
 
d3b822e
f7dab62
 
 
 
 
 
 
 
 
 
e85a29f
 
f7dab62
e85a29f
 
 
 
 
 
 
 
 
f7dab62
 
e85a29f
d3b822e
f7dab62
e85a29f
d3b822e
f7dab62
d3b822e
e85a29f
d3b822e
8d1d58a
d3b822e
e85a29f
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
import torch.nn as nn
import pandas as pd
from sklearn.model_selection import train_test_split
from fastapi import FastAPI, Request
from pydantic import BaseModel
from typing import Optional
import uvicorn
import os

# --- Tokenizer ---
class ScratchTokenizer:
    """Whitespace word-level tokenizer whose vocabulary is built from raw text.

    Ids 0-3 are reserved for the special tokens below; every other word gets
    the next free id in first-seen order, so the mapping is fully determined
    by the order of the texts passed to ``build_vocab``.
    """

    # Reserved ids; they match the initial entries of word2idx/idx2word.
    PAD_ID, SOS_ID, EOS_ID, UNK_ID = 0, 1, 2, 3

    def __init__(self):
        self.word2idx = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.idx2word = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.vocab_size = 4

    def build_vocab(self, texts):
        """Add every previously unseen whitespace-separated word in ``texts``.

        Args:
            texts: iterable of strings; each is split on whitespace.
        """
        for text in texts:
            for word in text.split():
                if word not in self.word2idx:
                    self.word2idx[word] = self.vocab_size
                    self.idx2word[self.vocab_size] = word
                    self.vocab_size += 1

    def encode(self, text, max_len=200):
        """Encode ``text`` as exactly ``max_len`` ids: <SOS> words <EOS> <PAD>...

        Unknown words map to <UNK>. Words beyond ``max_len - 2`` are truncated
        so that <SOS> and <EOS> always fit.

        Raises:
            ValueError: if ``max_len`` < 2 (no room for <SOS> and <EOS>).
        """
        if max_len < 2:
            # Without this guard tokens[:max_len - 2] slices from the end and
            # the result would be longer than max_len.
            raise ValueError(f"max_len must be at least 2 to hold <SOS> and <EOS>, got {max_len}")
        body = [self.word2idx.get(word, self.UNK_ID) for word in text.split()]
        tokens = [self.SOS_ID] + body[: max_len - 2] + [self.EOS_ID]
        return tokens + [self.PAD_ID] * (max_len - len(tokens))

    def decode(self, tokens, *, skip_special_tokens=False):
        """Map ids back to a space-joined string.

        Ids <= 0 (i.e. <PAD>) are always dropped and unknown ids render as
        "<UNK>". With ``skip_special_tokens=True`` the <SOS> and <EOS> markers
        are dropped as well; <UNK> is kept because it stands in for a word.
        """
        skipped = {self.SOS_ID, self.EOS_ID} if skip_special_tokens else set()
        words = [
            self.idx2word.get(idx, "<UNK>")
            for idx in tokens
            if idx > self.PAD_ID and idx not in skipped
        ]
        return " ".join(words)

# --- Load and Prepare Data ---
# The dataset is fetched over the network at import time; any failure here
# (no network, file moved) raises and stops this module from importing.
url = "https://drive.google.com/uc?id=1RCZShB5ohy1HdU-mogcP16TbeVv9txpY"
df = pd.read_csv(url)
# Only the 80% split is kept, and only to rebuild the tokenizer vocabulary.
# random_state=42 fixes the shuffle so the same CSV yields the same split.
train_data, _ = train_test_split(df, test_size=0.2, random_state=42)

tokenizer = ScratchTokenizer()
# Word ids are assigned in first-seen order over all instructions followed by
# all responses (assumes the CSV has "instruction"/"response" string columns).
# NOTE(review): this must reproduce exactly the vocabulary the checkpoint was
# trained with, since vocab_size fixes the model's embedding/output shapes.
# If the remote CSV changes, the checkpoint will likely stop loading or map
# words to the wrong ids. Consider persisting the vocabulary with the weights.
tokenizer.build_vocab(train_data["instruction"].tolist() + train_data["response"].tolist())

# --- Model ---
class GPTModel(nn.Module):
    """Token-to-token Transformer decoder conditioned on an embedded source.

    The source ids are embedded (token + learned position) and used directly
    as the cross-attention memory of an ``nn.TransformerDecoder`` stack; there
    is no separate encoder. Inputs are ``(batch, seq)`` id tensors; the output
    is ``(batch, tgt_seq, vocab_size)`` logits. Because the positional table
    has ``max_len`` rows, sequences longer than ``max_len`` cannot be embedded.
    """

    def __init__(self, vocab_size, embed_size=256, num_heads=8, num_layers=6, max_len=200):
        super().__init__()
        # These attribute names are the checkpoint's state_dict keys; keep
        # them (and their construction order) unchanged.
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, embed_size))
        decoder_layer = nn.TransformerDecoderLayer(d_model=embed_size, nhead=num_heads)
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(embed_size, vocab_size)

    def _add_positions(self, token_ids):
        """Embed ``(batch, seq)`` ids and add the positional rows for 0..seq-1."""
        seq_len = token_ids.size(1)
        return self.embedding(token_ids) + self.pos_embedding[:, :seq_len, :]

    def forward(self, src, tgt):
        """Return next-token logits for every target position.

        Args:
            src: ``(batch, src_seq)`` source ids, used as attention memory.
            tgt: ``(batch, tgt_seq)`` target ids decoded so far.
        """
        # The decoder layers use the default sequence-first layout, so swap
        # the batch and sequence axes on the way in and back on the way out.
        memory = self._add_positions(src).transpose(0, 1)
        target = self._add_positions(tgt).transpose(0, 1)
        # Causal mask: position i may only attend to target positions <= i.
        causal_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
        decoded = self.transformer(target, memory, tgt_mask=causal_mask)
        return self.fc_out(decoded.transpose(0, 1))

# --- Load Model ---
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# vocab_size must equal the vocabulary size the checkpoint was trained with;
# otherwise load_state_dict rejects the embedding/output layer shapes.
model = GPTModel(tokenizer.vocab_size).to(device)

def load_model(model, path="gpt_model.pth", map_location=None):
    """Load saved weights from ``path`` into ``model`` in place, if present.

    Args:
        model: module whose state_dict keys and shapes must match the file
            (presumably written with ``torch.save(model.state_dict(), path)``).
        path: checkpoint file to read.
        map_location: device to map the stored tensors onto. Defaults to the
            device of the model's first parameter, or the CPU if it has none.

    Returns:
        True if weights were loaded and the model switched to eval mode, or
        False if ``path`` does not exist (the model keeps its current weights).
    """
    if not os.path.exists(path):
        # Deliberately non-fatal so the service can still start, but callers
        # can now detect that they are running an unloaded model.
        print("Model file not found!")
        return False
    if map_location is None:
        first_param = next(model.parameters(), None)
        map_location = first_param.device if first_param is not None else "cpu"
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (weights_only=True is safer on torch versions that support it).
    state_dict = torch.load(path, map_location=map_location)
    model.load_state_dict(state_dict)
    model.eval()
    print("Model loaded successfully.")
    return True

# Load weights once at import time. If the checkpoint is missing, only a
# message is printed and the API serves the untrained, randomly initialised model.
load_model(model)

# --- Inference ---
def generate_response(query, max_length=200):
    """Greedily decode a response to ``query`` with the module-level model.

    Args:
        query: input text, encoded with the module-level tokenizer.
        max_length: maximum number of decoding steps. It is additionally
            capped by the model's positional table size, since the target
            sequence fed to the model grows by one token per step.

    Returns:
        The generated words joined by spaces, without the leading <SOS> and
        a trailing <EOS> marker.
    """
    sos_id, eos_id = 1, 2  # ids reserved by ScratchTokenizer
    # One learned position vector per slot: neither the source nor the
    # growing target may exceed this many tokens.
    max_positions = model.pos_embedding.size(1)
    steps = min(max_length, max_positions)
    model.eval()
    with torch.no_grad():
        src = torch.tensor(tokenizer.encode(query, max_len=max_positions)).unsqueeze(0).to(device)
        tgt = torch.tensor([[sos_id]], device=device)
        for _ in range(steps):
            logits = model(src, tgt)
            # Greedy pick from the last position's distribution -> shape (1, 1).
            next_word = logits[:, -1, :].argmax(-1, keepdim=True)
            tgt = torch.cat([tgt, next_word], dim=1)
            if next_word.item() == eos_id:
                break
    # Drop the seed <SOS> and, if decoding stopped on it, the final <EOS>.
    generated = tgt.squeeze(0).tolist()[1:]
    if generated and generated[-1] == eos_id:
        generated = generated[:-1]
    return tokenizer.decode(generated)

# --- FastAPI App ---
# ASGI application exposing the endpoints below; run it with an ASGI server
# such as uvicorn.
app = FastAPI()

class QueryRequest(BaseModel):
    """JSON request body for POST /query."""

    # Raw user text; the endpoint strips surrounding whitespace before use.
    query: str

@app.get("/")
def root():
    """Liveness endpoint: reports that the API process is up and serving."""
    status_message = "Transformer-based Response Generator API is running!"
    return {"message": status_message}

@app.post("/query")
def query_model(data: QueryRequest):
    """Generate a model response for the submitted query.

    Returns ``{"query": ..., "response": ...}`` with status 200. A blank
    (empty or whitespace-only) query gets ``{"error": "Query cannot be empty"}``
    with status 400 Bad Request.
    """
    # Local import keeps this endpoint self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.responses import JSONResponse

    query = data.query.strip()
    if not query:
        # Same body as before, but a client-error status instead of 200 so
        # HTTP clients and monitoring can tell the request failed.
        return JSONResponse(status_code=400, content={"error": "Query cannot be empty"})
    # Plain `def` endpoint: FastAPI runs it in a worker thread, so the
    # blocking generation loop does not stall the event loop.
    response = generate_response(query)
    return {"query": query, "response": response}