from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import pickle

import torch
import uvicorn

from constants import block_size
from tokenizer import SimpleTokenizer
from transformer import Classifier

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

tokenizer = SimpleTokenizer()

# Load the president dict (name -> class index) and invert it so a predicted
# class index can be mapped back to a president name.
with open('speechesdataset/pres_dict.pkl', 'rb') as file1:
    reversed_dict = pickle.load(file1)
pres_dict = {value: key for key, value in reversed_dict.items()}

# Restore the tokenizer state saved at training time.
with open('speechesdataset/tokenizer_stoi.pkl', 'rb') as file:
    tokenizer.stoi = pickle.load(file)
with open('speechesdataset/tokenizer_itos.pkl', 'rb') as file:
    tokenizer.itos = pickle.load(file)
with open('speechesdataset/tokenizer_vocab.pkl', 'rb') as file:
    tokenizer.vocab = pickle.load(file)
tokenizer.vocab_size = len(tokenizer.stoi)

# Load the trained classifier checkpoint onto the available device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Classifier(tokenizer.vocab_size)
print('loading model')
model.load_state_dict(torch.load('speechesdataset/checkpoint.pt', map_location=device))
print('finished loading model')
model.to(device)
model.eval()


class TextInput(BaseModel):
    text: str


@app.get("/")
def home():
    return {"message": "Welcome! The server is running. Send a POST request to /predict with \"text\" in the JSON body."}


@app.post("/predict")
def predict(request: TextInput):
    text = request.text  # Text from the POST request body

    # Perform inference.
    with torch.no_grad():
        wordids = tokenizer.encode(text)
        # Truncate to block_size and pad with zeros; a negative pad count
        # multiplies out to an empty list, so long inputs are simply truncated.
        padded_sentence = wordids[:block_size] + [0] * (block_size - len(wordids))
        input_tensor = torch.tensor(padded_sentence, dtype=torch.long).unsqueeze(0).to(device)
        print('input', input_tensor)
        output, _ = model(input_tensor)
        _, predicted = torch.max(output.data, 1)

    return {"predicted": pres_dict[predicted.tolist()[0]]}
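
# A minimal way to run and exercise the server (a sketch, not part of the
# trained artifacts above): the module name "server", the host, the port, and
# the sample text below are assumptions; adjust them to your setup, or start
# the app with `uvicorn server:app` instead.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example client call once the server is up (illustrative only):
#
#   import requests
#   resp = requests.post("http://localhost:8000/predict",
#                        json={"text": "Ask not what your country can do for you"})
#   print(resp.json())  # -> {"predicted": "<president name>"}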