# FastAPI inference server for the text classifier (deployed as a Hugging Face Space).
import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from constants import block_size
from main import load_texts
from tokenizer import SimpleTokenizer
from transformer import Classifier
| app = FastAPI() | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| block_size = 1024 | |
| model = None | |
| tokenizer = None | |
| def initialize(): | |
| global model, tokenizer | |
| if not tokenizer: | |
| tokenizer = SimpleTokenizer() | |
| print('start tokenizer') | |
| for text in load_texts('train.tsv'): | |
| tokenizer.update_vocab(text.split('\t', 1)[1]) | |
| print(tokenizer.vocab) | |
| print('finish tokenizer, vocab size is: ', tokenizer.vocab_size) | |
| if not model: | |
| model = Classifier(tokenizer.vocab_size) | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| print('loading model') | |
| model.load_state_dict(torch.load('all_pres_classifier_model_dict.pth', map_location=device)) | |
| print('finished loading model') | |
| model.to(device) | |
| model.eval() | |
| class TextInput(BaseModel): | |
| text: str | |
| def home(): | |
| return {"message": "Welcome! The server is running! Send a POST request to /predict please."} | |
| def predict(request: TextInput): | |
| global model, tokenizer | |
| if model is None or tokenizer is None: | |
| initialize() | |
| text = request.text | |
| # Get the text from the POST request body | |
| # Perform inference | |
| with torch.no_grad(): | |
| wordids = tokenizer.encode(text) | |
| padded_sentence = wordids[:block_size] + [0] * (block_size - len(wordids)) | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| input_tensor = torch.tensor(padded_sentence, dtype=torch.long).unsqueeze(0).to(device) | |
| print('input', input_tensor) | |
| output, _ = model(input_tensor) | |
| _, predicted = torch.max(output.data, 1) | |
| return {"predicted": predicted.tolist()} | |