NTO-TCP-HF / src/api/mannequin_to_model_api.py
ishworrsubedii's picture
update: mto return base64
d187e5b
"""
project @ CTO_TCP_ZERO_GPU
created @ 2024-11-14
author @ github.com/ishworrsubedii
"""
import base64
import json
import os
import time

import replicate
import requests
from fastapi import File, Form, UploadFile
from fastapi.routing import APIRouter
from starlette.concurrency import run_in_threadpool
from starlette.responses import JSONResponse

from src.api.nto_api import supabase
from src.utils.logger import logger
# Router that collects the mannequin-to-model endpoints declared in this module;
# presumably included into the FastAPI app elsewhere (not visible here).
mto_router: APIRouter = APIRouter()
def run_mto(input):
    """Run the pinned Replicate face-swap model on ``input``.

    Args:
        input: Payload forwarded unchanged as the ``input=`` argument of
            ``replicate.run`` (callers in this module pass a dict with
            ``local_source`` / ``local_target`` data URIs).

    Returns:
        Whatever ``replicate.run`` returns, or ``None`` if anything in the
        call raised; the error is logged rather than propagated.
    """
    model_ref = (
        "xiankgx/face-swap:"
        "cff87316e31787df12002c9e20a78a017a36cb31fde9862d8dedd15ab29b7288"
    )
    try:
        logger.info("Starting mannequin to model conversion")
        result = replicate.run(model_ref, input=input)
        logger.info("Mannequin to model conversion completed successfully")
    except Exception as err:
        logger.error(f"Error in mannequin to model conversion: {str(err)}")
        result = None
    return result
def read_return(url, timeout=30):
    """Download ``url`` with an HTTP GET and return the response body.

    Args:
        url: Absolute URL to fetch.
        timeout: Seconds ``requests`` may wait to connect and between reads.
            Without a timeout a stalled server can block the caller forever.

    Returns:
        The body as ``bytes``, or ``None`` if the request raised, timed out,
        or the server answered with a 4xx/5xx status. Failures are logged,
        not raised, so callers only need to test for ``None``.
    """
    try:
        res = requests.get(url, timeout=timeout)
        # Without this an error page (e.g. a 404 for a missing storage object)
        # would be handed back as if it were the requested image.
        res.raise_for_status()
    except Exception as e:
        logger.error(f"Error fetching image: {str(e)}")
        return None
    logger.info("Image fetched successfully")
    return res.content
# Public Supabase storage prefix under which the mannequin images are stored.
_MANNEQUIN_BASE_URL = (
    "https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/ClothingTryOn"
)


def _to_data_uri(data, mime_type):
    """Return ``data`` (bytes) as a base64 ``data:`` URI labelled ``mime_type``."""
    return f"data:{mime_type};base64,{base64.b64encode(data).decode('utf-8')}"


def _image_mime(content_type, default):
    """Return the bare ``image/*`` type from ``content_type``, else ``default``."""
    mime = (content_type or "").split(";")[0].strip().lower()
    return mime if mime.startswith("image/") else default


def _download_output_image(url):
    """Fetch the generated image at ``url`` and return ``(bytes, mime_type)``.

    Blocking. Raises ``requests`` exceptions on network failure, timeout, or a
    non-2xx status, so an error page is never returned as if it were the image.
    """
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    # Keep the historical "image/webp" label when no image type is reported.
    return response.content, _image_mime(response.headers.get("Content-Type"), "image/webp")


@mto_router.post("/mto_image")
async def mto_image(image: UploadFile = File(...), store_name: str = Form(...),
                    clothing_category: str = Form(...),
                    product_id: str = Form(...),
                    body_structure: str = Form(...),
                    skin_complexion: str = Form(...),
                    facial_structure: str = Form(...)):
    """Swap the face from the uploaded image onto a stored mannequin image.

    The mannequin image is looked up in public Supabase storage by store,
    clothing category, product id, skin complexion, facial structure and body
    structure. The upload is sent to the face-swap model as ``local_source``
    and the mannequin as ``local_target``.

    Returns:
        200 JSON ``{"output": <base64 data URI>, "status": "success",
        "inference_time": <seconds, 2 d.p.>}`` on success. On failure, 500 JSON
        ``{"error": ...}`` if the mannequin cannot be fetched, the face swap
        fails, the result cannot be downloaded, or anything else raises.
    """
    start_time = time.perf_counter()
    try:
        logger.info(f"Starting MTO image process for store: {store_name}")
        if body_structure == "medium":
            # Stored asset names use "fat" where clients send "medium".
            body_structure = "fat"
            logger.info("Body structure adjusted from 'medium' to 'fat'")

        image_bytes = await image.read()
        logger.info("Source image read successfully")

        mannequin_image_url = (
            f"{_MANNEQUIN_BASE_URL}/{store_name}/{clothing_category}/{product_id}/"
            f"{product_id}_{skin_complexion}_{facial_structure}_{body_structure}.webp"
        )
        logger.info("Fetching mannequin image")
        # requests / replicate calls are blocking; run them in the threadpool so
        # a slow download or prediction does not stall the event loop.
        reference_image_bytes = await run_in_threadpool(read_return, mannequin_image_url)
        if reference_image_bytes is None:
            logger.error("Failed to fetch reference image")
            return JSONResponse({"error": "Failed to fetch reference image"}, status_code=500)

        model_input = {
            # Label each image with its actual type (the mannequin asset is .webp)
            # instead of declaring both to be JPEG.
            "local_source": _to_data_uri(
                image_bytes, _image_mime(image.content_type, "image/jpeg")
            ),
            "local_target": _to_data_uri(reference_image_bytes, "image/webp"),
        }
        output = await run_in_threadpool(run_mto, model_input)
        if output is None:
            logger.error("Face swap process failed")
            return JSONResponse({"error": "Face swap process failed"}, status_code=500)

        try:
            # output["image"] is expected to reference the generated image;
            # a missing key or bad URL is reported as a processing error.
            image_content, output_mime = await run_in_threadpool(
                _download_output_image, str(output["image"])
            )
            output_uri = _to_data_uri(image_content, output_mime)
        except Exception as e:
            logger.error(f"Error converting output to base64: {str(e)}")
            return JSONResponse({"error": "Error processing output image"}, status_code=500)

        logger.info("MTO image process completed successfully")
        return JSONResponse(content={
            "output": output_uri,
            "status": "success",
            "inference_time": round(time.perf_counter() - start_time, 2),
        }, status_code=200)
    except Exception as e:
        logger.error(f"Error in MTO image process: {str(e)}")
        return JSONResponse({"error": str(e)}, status_code=500)
def _load_mannequin_catalogue():
    """Download ``MannequinInfo.json`` from the Supabase ``JSON`` bucket and parse it.

    Blocking: creates a signed URL (``expires_in=3600``) and performs an HTTP
    GET. Raises on network errors, timeout, or a non-2xx response.
    """
    signed_url = supabase.storage.get_bucket("JSON").create_signed_url(
        path="MannequinInfo.json",
        expires_in=3600,
    )["signedURL"]
    logger.info("Fetching JSON data from Supabase")
    response = requests.get(signed_url, timeout=30)
    response.raise_for_status()
    # json.loads accepts bytes and detects UTF-8/16/32 (including a UTF-8 BOM).
    return json.loads(response.content)


@mto_router.get("/mannequin_catalogue")
async def returnJsonData(gender: str):
    """Return the mannequin catalogue entries for the requested gender.

    ``gender`` is lower-cased and must be "female" or "male". Entries are kept
    when their ``gender`` field equals that value. Any other query yields an
    empty list (with a warning) without fetching the catalogue.

    Returns:
        A list of catalogue entries, or a 500 JSON ``{"error": ...}`` if the
        catalogue could not be fetched or parsed.
    """
    try:
        logger.info(f"Fetching mannequin catalogue for gender: {gender}")
        wanted = gender.lower()
        if wanted not in ("female", "male"):
            logger.warning(f"Invalid gender parameter: {gender}")
            res = []
        else:
            # Supabase client and requests are blocking; keep them off the event loop.
            mannequin_data = await run_in_threadpool(_load_mannequin_catalogue)
            # .get() so a single entry without "gender" doesn't fail the request.
            res = [item for item in mannequin_data if item.get("gender") == wanted]
        logger.info(f"Successfully retrieved {len(res)} mannequin entries")
        return res
    except Exception as e:
        logger.error(f"Error in mannequin catalogue: {str(e)}")
        return JSONResponse({"error": str(e)}, status_code=500)