Spaces:
Sleeping
Sleeping
| import fastapi | |
| import enum | |
| import numpy | |
| import pydantic | |
| import typing | |
| import jwt | |
| import jwt.algorithms | |
| import os | |
class ModelName(str, enum.Enum):
    """Names of the models understood by ``get_model``.

    Mixing in ``str`` makes every member compare equal to its plain string
    value, e.g. ``ModelName.lenet == "lenet"``.
    """

    alexnet = "alexnet"
    resnet = "resnet"
    lenet = "lenet"
app = fastapi.FastAPI()

# Debug output at import time: prints whatever
# jwt.algorithms.get_default_algorithms() returns in this environment.
print(jwt.algorithms.get_default_algorithms())

# Sample data sliced by get_array: the integers 0..999.
# numpy.arange builds this directly instead of materializing a Python range
# and converting it element by element.
array = numpy.arange(1000)
def check_valid_email(email: str):
    """Validate that ``email`` has the shape ``local@domain``.

    This is a deliberately light syntactic check (it is used as a pydantic
    ``AfterValidator`` on ``User.email``). It requires an ``@`` with
    non-empty text on both sides of the last ``@``. It does not check that
    the address exists or that the domain is well-formed.

    Args:
        email: Candidate address.

    Returns:
        ``email`` unchanged when it passes the check.

    Raises:
        ValueError: If ``email`` has no ``@``, or the text before or after
            the last ``@`` is empty.
    """
    # Splitting on the last "@" rejects bare "@", "user@" and "@domain",
    # all of which a plain `"@" in email` test accepted.
    local_part, at_sign, domain = email.rpartition("@")
    if at_sign and local_part and domain:
        return email
    raise ValueError("Email is invalid")
def index():
    """Return a fixed placeholder payload."""
    response = dict(name="First Data")
    return response
async def read_item(item_id: int):
    """Echo the requested ``item_id`` back to the caller."""
    payload = {}
    payload["item_id"] = item_id
    return payload
async def read_user_me():
    """Return a hard-coded placeholder standing in for the current user."""
    placeholder = "the current user"
    return dict(user_id=placeholder)
async def read_user(user_id: str):
    """Echo ``user_id`` back to the caller."""
    return dict(user_id=user_id)
async def get_model(model_name: ModelName):
    """Return a canned message for the requested model.

    Args:
        model_name: One of the ``ModelName`` members.

    Returns:
        A dict with the echoed ``model_name`` and a per-model ``message``.
    """
    # Compare against enum members uniformly. Because ModelName mixes in
    # ``str``, these comparisons also work if a plain string is passed, where
    # the previous ``model_name.value`` access would raise AttributeError.
    if model_name == ModelName.alexnet:
        return {"model_name": model_name, "message": "Deep Learning FTW!"}
    if model_name == ModelName.lenet:
        return {"model_name": model_name, "message": "LeCNN all the images"}
    # Any other value falls through here; among ModelName members that is
    # only ModelName.resnet.
    return {"model_name": model_name, "message": "Have some residuals"}
async def read_file(file_path: str):
    """Echo ``file_path`` back; the path is not opened or read."""
    return dict(file_path=file_path)
async def get_array(start: int, skip: int = 10):
    """Return a window of the module-level ``array``.

    Args:
        start: Index of the first element to return; must be >= 0.
        skip: Number of elements to return (despite its name, this is the
            window length, not an offset); must be >= 0.

    Returns:
        ``{"array": [...]}`` holding at most ``skip`` elements. The list is
        shorter, possibly empty, when the window runs past the end of ``array``.

    Raises:
        fastapi.HTTPException: Status 422 when ``start`` or ``skip`` is negative.
    """
    # numpy treats negative indices as offsets from the end, which produced
    # inconsistent windows: e.g. start=-10, skip=10 slices array[-10:0] -> [].
    if start < 0 or skip < 0:
        raise fastapi.HTTPException(
            status_code=422, detail="start and skip must be non-negative"
        )
    # Slice and convert once, then reuse for both the debug print and the
    # response.
    window = array[start : start + skip].tolist()
    print(window)
    return {"array": window}
| async def test(test : bool): | |
| return {"test": test} | |
async def read_item(
    channel_owner: str, product_id: int, q: typing.Optional[str] = None
):
    """Echo a product lookup back to the caller.

    NOTE(review): this rebinds the module-level name ``read_item``, which is
    already defined earlier in this file with an ``item_id`` parameter. A
    route decorator (not visible here) would still register both handlers,
    but any direct reference to ``read_item`` resolves to this one. Consider
    renaming one of them.

    Args:
        channel_owner: Owner of the product's channel; returned as
            ``product_owner``.
        product_id: Product identifier.
        q: Optional free-text value echoed back; ``None`` when omitted.

    Returns:
        A dict with ``product_owner``, ``product_id`` and ``q``.
    """
    # `q` defaults to None, so its annotation must admit None explicitly;
    # implicit Optional (`q: str = None`) is rejected by PEP 484 type checkers.
    return {"product_owner": channel_owner, "product_id": product_id, "q": q}
class User(pydantic.BaseModel):
    """Pydantic model for the user data accepted by ``create_user`` and ``login``.

    Fields:
        username: Between 3 and 32 characters long.
        display_name: Any string.
        email: Must pass ``check_valid_email``.
    """

    # pydantic.Field is the marker for constraints on a model field.
    # fastapi.Query is intended for declaring query-string parameters on a
    # path operation function, not for fields of a body model.
    username: typing.Annotated[str, pydantic.Field(min_length=3, max_length=32)]
    display_name: str
    email: typing.Annotated[str, pydantic.AfterValidator(check_valid_email)]
async def create_user(user: User):
    """Return the validated ``user`` wrapped under the ``"user"`` key.

    Nothing is persisted; the model is simply echoed back.
    """
    response = dict(user=user)
    return response
async def get_followers(
    channel_owner: str,
    limit: typing.Annotated[
        int, fastapi.Path(title="The number of followers to retrieve")
    ],
):
    """Echo ``channel_owner`` and ``limit`` back to the caller.

    NOTE(review): ``limit`` is declared with ``fastapi.Path``, so it is
    expected to appear as ``{limit}`` in the route path. The route decorator
    is not visible here. Confirm this, or switch to ``fastapi.Query`` if
    ``limit`` is meant to be a query-string parameter.
    """
    return dict(channel_owner=channel_owner, limit=limit)
def encode_data(user: User):
    """Serialize ``user`` into a signed JWT.

    The key and algorithm are read from the ``ENCRYPTION_KEY`` and
    ``ENCRYPTION_ALGORITHM`` environment variables on every call. Despite the
    variable names, ``jwt.encode`` signs the token rather than encrypting it.
    The claims (username, display name, email) are therefore readable by
    anyone holding the token.

    Args:
        user: The user whose ``username``, ``display_name`` and ``email``
            become the token's claims.

    Returns:
        The encoded token as returned by ``jwt.encode``.

    Raises:
        ValueError: If either environment variable is unset.
    """
    encryption_key = os.environ.get("ENCRYPTION_KEY")
    encryption_algorithm = os.environ.get("ENCRYPTION_ALGORITHM")
    if encryption_key is None:
        raise ValueError("No ENCRYPTION_KEY set for the application")
    if encryption_algorithm is None:
        raise ValueError("No ENCRYPTION_ALGORITHM set for the application")
    # The secret key and the issued token are deliberately not printed.
    # Stdout typically ends up in logs, and the key lets a reader forge tokens
    # while the token lets a reader replay it.
    payload = {
        "username": user.username,
        "display_name": user.display_name,
        "email": user.email,
    }
    # PyJWT's jwt.encode looks up the algorithm itself and raises
    # NotImplementedError for an unsupported name, so no separate
    # get_algorithm_by_name call is needed.
    return jwt.encode(payload, encryption_key, algorithm=encryption_algorithm)
def decode_data(token: str):
    """Verify and decode a token produced by ``encode_data``.

    The key and the single accepted algorithm are read from the
    ``ENCRYPTION_KEY`` and ``ENCRYPTION_ALGORITHM`` environment variables.
    Pinning ``algorithms`` to that one configured value prevents a token from
    choosing its own algorithm.

    Args:
        token: The encoded JWT.

    Returns:
        The decoded claims on success, or ``None`` if the token is malformed,
        has a bad signature, or otherwise fails PyJWT's validation.

    Raises:
        ValueError: If ``ENCRYPTION_KEY`` or ``ENCRYPTION_ALGORITHM`` is unset.
    """
    encryption_key = os.environ.get("ENCRYPTION_KEY")
    encryption_algorithm = os.environ.get("ENCRYPTION_ALGORITHM")
    if encryption_key is None:
        raise ValueError("No ENCRYPTION_KEY set for the application")
    if encryption_algorithm is None:
        raise ValueError("No ENCRYPTION_ALGORITHM set for the application")
    try:
        payload = jwt.decode(token, encryption_key, algorithms=[encryption_algorithm])
    except jwt.InvalidTokenError:
        # PyJWT has no `jwt.JWTError` (that name comes from python-jose).
        # Evaluating it raised AttributeError as soon as a bad token arrived,
        # turning "invalid token" into a server error. InvalidTokenError is
        # PyJWT's base class for decode and validation failures.
        return None
    print(payload)
    return payload
async def login(user: User):
    """Issue a signed token for ``user`` and return it under ``"token"``.

    NOTE(review): no credential is checked before the token is issued, so any
    caller can obtain a token for any username. Confirm this is intended.
    """
    token = encode_data(user)
    # Debug round-trip: decode the freshly issued token and print the result.
    decoded = decode_data(token)
    print(decoded)
    return {"token": token}
async def verify(token: str):
    """Return the decoded claims for ``token``, or ``None`` if ``decode_data`` rejects it."""
    claims = decode_data(token)
    return claims
| # @app.get("") |