limbsAI_API / app.py
Miguel Cid Flor
typos
6b54c6f
raw
history blame
1.54 kB
import base64
from io import BytesIO

import cv2
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
import numpy as np
from PIL import Image
import torch # or tensorflow

from Models import ResPoseNet, transform
from PreProcessor import transform_data
# Browser origins permitted to make cross-origin requests to this API
# (credentials included). Only the front end at this address is allowed.
_ALLOWED_ORIGINS = ["http://127.0.0.1:5500"]

# CORS policy: from the allowed origins, accept every HTTP method and every
# request header.
_CORS_OPTIONS = {
    "allow_origins": _ALLOWED_ORIGINS,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

app = FastAPI()
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# Build the pose network once at import time and restore its saved weights.
# The checkpoint path is relative, so it resolves against the process's
# current working directory. map_location places every stored tensor on the
# CPU, so the checkpoint loads even on a host without the device it was
# saved from.
_WEIGHTS_PATH = "posev0.01126.pth"
_CPU = torch.device("cpu")

model = ResPoseNet()
model.load_state_dict(torch.load(_WEIGHTS_PATH, map_location=_CPU))
def predict_keypoints(image, model, scale=224):
    """Run the pose model on a single image and return its keypoints.

    Args:
        image: Input accepted by ``transform`` (imported from ``Models``).
            A batch dimension of size 1 is added to the transformed tensor
            here, so ``transform`` is expected to return an unbatched tensor.
        model: Network whose output, once flattened, is a sequence of
            alternating x, y values for the one image in the batch.
        scale: Factor applied to the raw model output. Defaults to 224, the
            value previously hard-coded. Presumably the model predicts
            normalized coordinates and 224 is the side length of the
            transformed input. NOTE(review): confirm against Models.transform.

    Returns:
        List of ``(x, y)`` float tuples, one per keypoint, in model order.
    """
    # Switch layers such as dropout/batch-norm to inference behaviour.
    model.eval()
    with torch.no_grad():
        img_tensor = transform(image).unsqueeze(0)
        output = model(img_tensor) * scale
    # Pair consecutive values (x0, y0, x1, y1, ...) with one bulk conversion
    # instead of a ``.item()`` call per coordinate. reshape(-1, 2) also
    # accepts a (1, K, 2) output, which the old squeeze-and-index loop
    # could not convert to scalars.
    return [tuple(pair) for pair in output.reshape(-1, 2).tolist()]
def decode_base64_image(data):
    """Decode a base64-encoded image string into a NumPy array.

    Accepts either a data URL (``"data:image/png;base64,<payload>"``), in
    which case everything after the first comma is decoded, or a bare base64
    payload with no header.

    Args:
        data: The encoded image as a string.

    Returns:
        ``np.ndarray`` built from the decoded PIL image. Its channel count
        follows the image's own mode (for example, a PNG exported from a
        browser canvas may carry an alpha channel). NOTE(review): confirm
        what ``transform_data`` expects.

    Raises:
        binascii.Error: (a ``ValueError`` subclass) if the payload is not
            valid base64.
        PIL.UnidentifiedImageError: (an ``OSError`` subclass) if the bytes
            are not a recognised image format.
    """
    # partition() tolerates a missing "data:...;base64," header, whereas
    # split(",", 1) failed to unpack when no comma was present.
    _, sep, payload = data.partition(",")
    if not sep:
        payload = data
    with Image.open(BytesIO(base64.b64decode(payload))) as img:
        # Image.open is lazy; np.array forces the pixel data to load while
        # the image is still open.
        return np.array(img)
@app.post("/predict")
async def predict(request: Request):
    """Predict pose keypoints for an image posted as JSON.

    Expects a JSON object body with an ``"image"`` field holding a
    base64-encoded image (a data URL or a bare payload). The image is
    preprocessed with ``transform_data``, passed through the model, and each
    predicted point is mapped back through the ``reverse`` callable that
    ``transform_data`` returns.

    Returns:
        ``{"keypoints": [...]}``, with one entry per predicted point as
        produced by ``reverse``.

    Raises:
        HTTPException: 400 if the body is not JSON, lacks a non-empty string
            ``"image"`` field, or the image cannot be decoded. Previously
            these cases surfaced as 500 Internal Server Error.
    """
    try:
        data = await request.json()
    except ValueError as err:  # json.JSONDecodeError subclasses ValueError
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON."
        ) from err

    image_data = data.get("image") if isinstance(data, dict) else None
    if not isinstance(image_data, str) or not image_data:
        raise HTTPException(
            status_code=400,
            detail='Expected a JSON object with a non-empty string field "image".',
        )

    try:
        img = decode_base64_image(image_data)
    except (ValueError, OSError) as err:
        # ValueError covers bad base64 (binascii.Error); OSError covers
        # unreadable or truncated image data (PIL.UnidentifiedImageError).
        raise HTTPException(
            status_code=400, detail="Could not decode the image payload."
        ) from err

    # NOTE(review): preprocessing and inference below run synchronously in
    # this async handler, so they block the event loop while they execute.
    # Consider fastapi.concurrency.run_in_threadpool if concurrent requests
    # matter.
    processed, _, reverse = transform_data(img, [])
    results = predict_keypoints(processed, model)
    keypoints = [reverse(x, y) for x, y in results]
    return {"keypoints": keypoints}