| from fastapi import FastAPI | |
| from fastapi.responses import HTMLResponse | |
| from pydantic import BaseModel, Extra | |
| import argparse | |
| from typing import Optional | |
| import uvicorn | |
| from model import ChallengePromptGenerator | |
class Prompt(BaseModel, extra="allow"):
    """Request body for the POST ``/`` prompt-completion endpoint.

    Unknown fields are accepted and kept instead of being rejected. The string
    form ``extra="allow"`` is used rather than the ``Extra.allow`` enum, which
    is deprecated (and warns on access) in pydantic v2; pydantic v1 coerces
    the string to the same ``Extra.allow`` setting.
    """

    # Text prefix the generator should complete.
    prompt: str
    # Client-supplied seed (defaults to 0).
    # NOTE(review): not visibly consumed by the handler in this file — confirm intent.
    seed: Optional[int] = 0
    # Requested upper bound on generated length (defaults to 77).
    max_length: Optional[int] = 77
def get_args(argv=None):
    """Parse command-line options for the challenge-prompt server.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the original behaviour. Passing an explicit
            list allows calling this from tests or other code.

    Returns:
        argparse.Namespace with ``port``, ``netuid``, ``min_stake``,
        ``chain_endpoint`` and ``disable_secure`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=10001)
    # Previously ``type=str`` with an int default: ``netuid`` was the int 23 when
    # omitted but a str when supplied. Parse it as int so the type is stable.
    parser.add_argument("--netuid", type=int, default=23)
    parser.add_argument("--min_stake", type=int, default=100)
    parser.add_argument("--chain_endpoint", type=str, default="finney")
    parser.add_argument("--disable_secure", action="store_true", default=False)
    return parser.parse_args(argv)
class ChallengeImage:
    """FastAPI service that completes text prompts with ``ChallengePromptGenerator``.

    Routes, both registered on ``/``:
        POST -- body is a ``Prompt``; returns the completed prompt string.
        GET  -- returns ``index.html`` (resolved against the current working
                directory) as an HTML response.
    """

    # Upper bound on the generation length passed to the generator. This is the
    # value the handler used to hard-code, and also the ``Prompt.max_length`` default.
    MAX_GENERATION_LENGTH = 77
    # Top-k sampling width passed to the generator.
    SAMPLING_TOPK = 100
    # Prefix completed when the client sends an empty prompt.
    DEFAULT_PROMPT = "an image of "

    def __init__(self):
        self.challenge_prompt = ChallengePromptGenerator()
        self.app = FastAPI(title="Challenge Prompt")
        self.app.add_api_route("/", self.__call__, methods=["POST"])
        self.app.add_api_route("/", self.serve_index, methods=["GET"])

    async def __call__(
        self,
        data: Prompt,
    ):
        """Complete ``data.prompt`` and return the first generated text, stripped.

        An empty prompt is replaced by ``DEFAULT_PROMPT``. ``data.max_length``
        was previously ignored. It is now honoured, but clamped to
        ``[1, MAX_GENERATION_LENGTH]`` so a client cannot raise the old fixed
        limit. A ``None`` value falls back to the maximum.
        """
        prompt = data.prompt or self.DEFAULT_PROMPT

        max_length = data.max_length
        if max_length is None:
            max_length = self.MAX_GENERATION_LENGTH
        max_length = max(1, min(max_length, self.MAX_GENERATION_LENGTH))

        # NOTE(review): data.seed is accepted but not used; the generator's seeding
        # API is not visible here — confirm whether it should be wired in.
        # NOTE(review): infer_prompt is called synchronously (not awaited) inside
        # this async handler, so it blocks the event loop while it runs.
        completions = self.challenge_prompt.infer_prompt(
            [prompt],
            max_generation_length=max_length,
            sampling_topk=self.SAMPLING_TOPK,
        )
        return completions[0].strip()

    async def serve_index(self):
        """Return ``index.html`` from the current working directory as HTML."""
        # Explicit encoding: without it, decoding depends on the host locale.
        with open("index.html", "r", encoding="utf-8") as file:
            return HTMLResponse(content=file.read(), status_code=200)
if __name__ == "__main__":
    # Flags are parsed before ChallengeImage() builds its ChallengePromptGenerator,
    # so `--help` or an invalid flag exits without constructing the generator.
    cli_args = get_args()
    print("Args: ", cli_args)
    service = ChallengeImage()
    uvicorn.run(service.app, host="0.0.0.0", port=cli_args.port)