import os
import secrets
import time
from typing import Union

from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel

from QuoteGenerator import QuoteGenerator
# Shared secret that clients must present as their bearer token;
# api_key_auth compares the incoming token against it. None when the
# API_KEY environment variable is not set.
API_KEY = os.environ.get('API_KEY')
# Extracts the bearer token from incoming requests for api_key_auth.
# NOTE(review): tokenUrl points at a "token" endpoint that is not defined
# anywhere in this file — confirm whether one exists.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def api_key_auth(api_key: str = Depends(oauth2_scheme)):
    """Reject the request unless its bearer token matches ``API_KEY``.

    Meant to be used as a FastAPI dependency: *api_key* is the bearer token
    extracted by ``oauth2_scheme``.

    Raises:
        HTTPException: 401 when ``API_KEY`` is unset or empty, or when the
            supplied token does not equal it.
    """
    # An unset/empty key must never authenticate anyone. With a plain `!=`,
    # API_KEY=None matched a None token and API_KEY="" matched an empty
    # bearer token ("Authorization: Bearer ").
    expected = API_KEY or ""
    provided = api_key or ""
    # compare_digest takes time independent of where the values differ, so
    # response timing does not reveal how much of the key was guessed.
    # Compare UTF-8 bytes because compare_digest rejects non-ASCII str.
    if not expected or not secrets.compare_digest(
        provided.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Forbidden Access",
            # A 401 response should carry an authentication challenge.
            headers={"WWW-Authenticate": "Bearer"},
        )
class QuoteRequest(BaseModel):
    """Parameters for generating a quote."""

    # Every field below is forwarded as a keyword argument of the same name
    # to quote_generator.generate_quote by generate_quote.
    # NOTE(review): the names mirror Hugging Face `generate()` options;
    # presumably QuoteGenerator passes them through — verify there.

    # Optional tag string for the quote; None means no tags were given.
    tags: Union[None, str] = None
    # Whether to sample instead of decoding deterministically.
    do_sample: bool = False
    # Upper bound on the number of tokens to generate.
    max_new_tokens: int = 16
    # Beam count; 1 presumably means no beam search — confirm.
    num_beams: int = 1
    # Sampling cutoffs (presumably only used when do_sample is True).
    top_k: int = 50
    top_p: float = 1.0
    # Sampling temperature.
    temperature: float = 1.0
# The FastAPI application instance for this service.
app = FastAPI()
async def note_response_time(request: Request, call_next):
    """Print how long the downstream handler took to produce a response.

    Has the ``(request, call_next)`` shape of a FastAPI/Starlette HTTP
    middleware. NOTE(review): no ``@app.middleware("http")`` registration
    is visible in this file — confirm it is actually installed.

    Returns:
        The response from ``call_next(request)``, unchanged.
    """
    # perf_counter is monotonic and high-resolution. time.time() follows
    # the wall clock, which can be stepped by NTP or an operator, producing
    # bogus or even negative durations.
    started = time.perf_counter()
    response = await call_next(request)
    elapsed = time.perf_counter() - started
    print(f'Time taken = {elapsed:.1f}s')
    return response
# Single module-level generator, created and loaded once at import time;
# generate_quote reuses it for every request instead of reloading per call.
# NOTE(review): load_generator presumably loads model weights, making module
# import slow and memory-heavy — confirm.
quote_generator = QuoteGenerator()
quote_generator.load_generator()
def root(request: QuoteRequest):
    """Return a fixed placeholder quote without calling the generator.

    The parsed request is echoed to stdout. NOTE(review): no route decorator
    is visible in this file — confirm how this handler is registered.
    """
    print("Incoming request\n", vars(request))
    placeholder = "<bot>:A beautiful quote generated by bot"
    return {"quote": placeholder}
def generate_quote(req: QuoteRequest):
    """Generate a quote using the options carried by *req*.

    Echoes the request to stdout, then forwards each generation option to
    ``quote_generator.generate_quote`` as a keyword argument of the same
    name. The generator's output is returned under the ``'quote'`` key.

    NOTE(review): no route decorator or ``api_key_auth`` dependency is
    visible in this file — confirm how this endpoint is exposed and
    whether it is meant to be protected.
    """
    print("\nIncoming request \n", vars(req), end='\n\n')
    # Same keyword names, in the same order, as an explicit argument list.
    option_names = (
        "tags",
        "max_new_tokens",
        "num_beams",
        "temperature",
        "top_k",
        "top_p",
        "do_sample",
    )
    options = {name: getattr(req, name) for name in option_names}
    return {'quote': quote_generator.generate_quote(**options)}