import torch import torch.nn as nn import torch.optim as optim import pandas as pd from torch.utils.data import Dataset, DataLoader from flask import Flask, request, jsonify, Response, stream_with_context from sklearn.model_selection import train_test_split import os import time import json # Set PyTorch to use all available CPU threads torch.set_num_threads(os.cpu_count()) torch.set_num_interop_threads(os.cpu_count()) url = "https://drive.google.com/uc?id=1RCZShB5ohy1HdU-mogcP16TbeVv9txpY" df = pd.read_csv(url) # Tokenizer class ScratchTokenizer: def __init__(self): self.word2idx = {"": 0, "< SOS >": 1, "": 2, "": 3} self.idx2word = {0: "", 1: "< SOS >", 2: "", 3: ""} self.vocab_size = 4 def build_vocab(self, texts): for text in texts: for word in text.split(): if word not in self.word2idx: self.word2idx[word] = self.vocab_size self.idx2word[self.vocab_size] = word self.vocab_size += 1 def encode(self, text, max_len=200): tokens = [self.word2idx.get(word, 3) for word in text.split()] tokens = [1] + tokens[:max_len - 2] + [2] return tokens + [0] * (max_len - len(tokens)) def decode(self, tokens): return " ".join([self.idx2word.get(idx, "") for idx in tokens if idx > 0]) # Train-Test Split train_data, test_data = train_test_split(df, test_size=0.2, random_state=42) # Initialize Tokenizer tokenizer = ScratchTokenizer() tokenizer.build_vocab(train_data["instruction"].tolist() + train_data["response"].tolist()) # Model class GPTModel(nn.Module): def __init__(self, vocab_size, embed_size=256, num_heads=8, num_layers=6, max_len=200): super(GPTModel, self).__init__() self.embedding = nn.Embedding(vocab_size, embed_size) self.pos_embedding = nn.Parameter(torch.randn(1, max_len, embed_size)) self.transformer = nn.TransformerDecoder( nn.TransformerDecoderLayer(d_model=embed_size, nhead=num_heads), num_layers=num_layers ) self.fc_out = nn.Linear(embed_size, vocab_size) def forward(self, src, tgt): src_emb = self.embedding(src) + self.pos_embedding[:, :src.size(1), :] tgt_emb = self.embedding(tgt) + self.pos_embedding[:, :tgt.size(1), :] tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device) output = self.transformer(tgt_emb.permute(1, 0, 2), src_emb.permute(1, 0, 2), tgt_mask=tgt_mask) return self.fc_out(output.permute(1, 0, 2)) # Load model device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = GPTModel(tokenizer.vocab_size).to(device) def load_model(model, path="gpt_model.pth"): if os.path.exists(path): model.load_state_dict(torch.load(path, map_location=device, weights_only=True)) model.eval() print("Model loaded successfully.") else: print("Model file not found!") return model def generate_response_stream(model, query, max_length=200): model.eval() # Pre-encode the query once src_tokens = tokenizer.encode(query) src = torch.tensor(src_tokens).unsqueeze(0).to(device) tgt = torch.tensor([[1]], dtype=torch.long).to(device) # < SOS > with torch.no_grad(): for step in range(max_length): # Forward pass output = model(src, tgt) # Get next token more efficiently logits = output[:, -1, :] next_token = torch.argmax(logits, dim=-1, keepdim=True) # Check for EOS early if next_token.item() == 2: # break # Concatenate token tgt = torch.cat([tgt, next_token], dim=1) # Get the current word current_word = tokenizer.idx2word.get(next_token.item(), "") if current_word not in ["", "", "< SOS >"]: yield current_word + " " # Flask App app = Flask(__name__) @app.route("/") def home(): return {"message": "Streaming Transformer-based Response Generator API is running!"} @app.route("/intent") def intents(): return jsonify({"intents": list(set(df['intent'].dropna()))}) @app.route("/query", methods=["POST"]) def query_model(): data = request.get_json() query = data.get("query", "") if not query: return jsonify({"error": "Query cannot be empty"}), 400 def generate(): start = time.time() word_count = 0 for word in generate_response_stream(model, query): word_count += 1 response_data = { "word": word.strip(), "timestamp": time.time() - start, "word_count": word_count } yield f"data: {json.dumps(response_data)}\n\n" return Response( stream_with_context(generate()), mimetype='text/event-stream', headers={ 'Cache-Control': 'no-cache', 'Connection': 'keep-alive' } ) if __name__ == "__main__": # Load model model = load_model(model) # Run Flask with optimizations app.run( host="0.0.0.0", port=7860, threaded=True, debug=False )