Spaces:
Runtime error
Runtime error
| import os | |
| from typing import List | |
| from fastapi import APIRouter | |
| from src.libs.image import save_img, generate_img_index | |
| from src.libs.model import CganCols, get_model | |
| from src.libs.s3 import s3client | |
| from src.models.generate import GenerateResult, ImageResult | |
| from src.models.main import User, Method, UserIndex | |
# Local scratch directory where generated PNGs are written before upload.
# NOTE(review): "./src/store" is resolved against the process's current working
# directory at import time, so the server must be started from the project
# root — confirm this matches the deployment.
IMAGE_STORE_PATH = os.path.abspath("./src/store")
# S3 bucket that receives the generated images.
BUCKET_NAME = "pimthaigans-image-container"
# Ensure IMAGE_STORE_PATH exists. exist_ok=True replaces the previous
# exists()-then-makedirs() check, which could raise FileExistsError when two
# processes import this module at the same time.
os.makedirs(IMAGE_STORE_PATH, exist_ok=True)
# Routes registered on this router are mounted under /generate, grouped under
# the "Generate" tag, and advertise a 404 "Not found" response.
router = APIRouter(
    tags=["Generate"],
    prefix="/generate",
    responses={404: {"description": "Not found"}},
)

# Instantiated once at import time; the module-level `model` is read by
# s3uploadimage for every image generated in this process.
model = CganCols()
async def info():
    """Return a short, static description of this router."""
    description = "This is the generate endpoint"
    return {"info": description}
async def status():
    """Return a static status payload indicating the service responded."""
    return dict(status="OK")
async def generate(user: UserIndex) -> GenerateResult:
    """Dispatch an image-generation request according to ``user.method``.

    ``Method.index`` generates only the image at ``user.index``; any other
    method value generates the full set via ``generate_all``.
    """
    single_image = user.method == Method.index
    if not single_image:
        return await generate_all(user.user)
    return await generate_index(user.user, user.index)
async def generate_index(user: User, index: int) -> GenerateResult:
    """Generate and upload the single image at ``index`` for ``user``.

    Args:
        user: the requesting user; passed through to ``s3uploadimage`` and
            echoed back in the result.
        index: which image to generate (``generate_all`` iterates 0-87).

    Returns:
        A ``GenerateResult`` with ``method=Method.index`` and a one-element
        ``result`` list.
    """
    s3 = s3client()
    try:
        # NOTE(review): s3uploadimage is synchronous (model inference + S3
        # upload), so this call blocks the event loop while it runs; consider
        # a plain `def` endpoint or a thread pool.
        img_detail = s3uploadimage(user, s3, index)
    finally:
        # Close the client even if generation/upload raises; previously an
        # exception skipped close() and leaked the client.
        s3.close()
    result: List[ImageResult] = [img_detail]
    return GenerateResult(user=user, method=Method.index, result=result)
async def generate_all(user: User) -> GenerateResult:
    """Generate and upload every image (indices 0-87) for ``user``.

    Args:
        user: the requesting user; passed through to ``s3uploadimage`` and
            echoed back in the result.

    Returns:
        A ``GenerateResult`` with ``method=Method.all`` and one ``ImageResult``
        per index, in index order.
    """
    s3 = s3client()
    try:
        # NOTE(review): 88 is hard-coded; s3uploadimage uses `index % 11`, so
        # this may correspond to 8 groups of 11 images — confirm against
        # get_model. All 88 uploads run synchronously and block the event loop.
        result: List[ImageResult] = [
            s3uploadimage(user, s3, index) for index in range(88)
        ]
    finally:
        # Always release the client, even if one of the uploads raises.
        s3.close()
    return GenerateResult(user=user, method=Method.all, result=result)
def s3uploadimage(user, s3, index):
    """Generate image ``index`` for ``user``, upload it to S3, return its URL.

    The image is rendered with the model selected by ``get_model(index)`` and
    written to a scratch PNG under IMAGE_STORE_PATH. It is then uploaded to
    ``BUCKET_NAME`` under the key ``<user.uuid>/<NN>.png`` (NN = index,
    zero-padded to two digits). The scratch file is deleted afterwards, also
    when saving or uploading fails.

    Args:
        user: object with a ``uuid`` attribute used in the file name and key.
        s3: client exposing ``upload_file(path, bucket, key)``.
        index: image index; selects the model and, via ``index % 11``, the
            image within that model.

    Returns:
        ImageResult carrying ``index`` and the object's S3 URL.
    """
    padded_index = str(index).zfill(2)
    output_path = os.path.join(IMAGE_STORE_PATH, f"{user.uuid}-{padded_index}.png")
    s3_path: str = f"{user.uuid}/{padded_index}.png"

    used_model = model.model_cols[get_model(index)]
    image = generate_img_index(reloaded_model=used_model, index=index % 11)
    try:
        save_img(image, output_path)
        s3.upload_file(output_path, BUCKET_NAME, s3_path)
    finally:
        # Remove the scratch file even on failure so IMAGE_STORE_PATH does not
        # fill up with orphaned PNGs from failed requests.
        try:
            os.remove(output_path)
        except FileNotFoundError:
            pass  # save_img failed before the file was created

    # NOTE(review): this URL is only reachable if the object/bucket is
    # publicly readable — confirm the bucket policy.
    image_url = f'https://{BUCKET_NAME}.s3.amazonaws.com/{s3_path}'
    return ImageResult(index=index, image_url=image_url)