Spaces:
Sleeping
Sleeping
import base64
from io import BytesIO

import cv2
import numpy as np
import torch  # or tensorflow
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image

from Models import ResPoseNet, transform
from PreProcessor import transform_data
# FastAPI application serving the pose-estimation endpoint.
app = FastAPI()

# Browsers may only call this API from the local front-end origin
# (http://127.0.0.1:5500). Requests from that origin may carry credentials
# and use any HTTP method and any request header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["http://127.0.0.1:5500"],
)
# Build the pose network once at import time and restore its trained weights.
# Checkpoint tensors are remapped onto the CPU, so no GPU is required.
# NOTE(review): torch.load unpickles the file. The checkpoint ships with the
# app, but consider weights_only=True (torch >= 1.13) to restrict it to tensors.
model = ResPoseNet()
model.load_state_dict(
    torch.load("posev0.01126.pth", map_location="cpu")
)
def predict_keypoints(image, model, scale=224, preprocess=None):
    """Run the pose model on a single image and return its keypoints.

    Args:
        image: Input accepted by ``preprocess`` (by default the ``transform``
            imported from ``Models``).
        model: A ``torch.nn.Module``. It is switched to eval mode before use.
        scale: Factor the raw model output is multiplied by. Defaults to 224;
            presumably the network predicts coordinates normalised by a
            224-pixel input size -- TODO confirm against the training code.
        preprocess: Callable mapping ``image`` to an unbatched tensor.
            Defaults to ``Models.transform``.

    Returns:
        list[tuple[float, float]]: One ``(x, y)`` pair per keypoint, in the
        order the model emits its values.

    Raises:
        RuntimeError: If the model emits an odd number of values, so they
            cannot be paired into (x, y) coordinates.
    """
    if preprocess is None:
        preprocess = transform
    model.eval()
    with torch.no_grad():
        batch = preprocess(image).unsqueeze(0)  # add a batch dimension of 1
        output = model(batch) * scale
    # Reshape whatever the model returns ((1, 2K), (1, K, 2), ...) into K rows
    # of (x, y). squeeze() plus scalar indexing only worked for flat outputs.
    return [tuple(pair) for pair in output.reshape(-1, 2).tolist()]
def decode_base64_image(data):
    """Decode a base64-encoded image string into a NumPy array.

    Accepts either a data URL (``"data:image/png;base64,<payload>"``) or a
    bare base64 payload with no ``header,`` prefix.

    Args:
        data: The base64 image string, optionally prefixed by a data-URL
            header that ends in a comma.

    Returns:
        np.ndarray: The decoded pixels, in whatever mode PIL opened the image
        with (no colour conversion is applied).

    Raises:
        ValueError: If the payload is not valid base64 (``binascii.Error``).
        OSError: If the bytes are not a recognisable image
            (``PIL.UnidentifiedImageError``).
    """
    # The payload is everything after the comma. Base64 never contains commas.
    # When there is no comma at all, rpartition yields the whole string as the
    # payload instead of failing to unpack like split(",", 1) did.
    _, _, encoded = data.rpartition(",")
    raw = base64.b64decode(encoded)
    # NOTE(review): browser canvases usually emit RGBA PNGs. Confirm that
    # transform_data copes with 4-channel arrays, or convert to RGB here.
    # np.array() copies the pixel data, so closing the image afterwards is safe.
    with Image.open(BytesIO(raw)) as img:
        return np.array(img)
# NOTE(review): the route path "/predict" is assumed from the function name.
# Confirm it against the URL the front-end posts to.
@app.post("/predict")
async def predict(request: Request):
    """Predict pose keypoints for a base64-encoded image sent as JSON.

    Expects a JSON body of the form ``{"image": "<data URL or base64>"}``.
    The decoded image is preprocessed with ``transform_data`` and passed to
    the model. Each predicted point is then mapped through the ``reverse``
    callable that ``transform_data`` returns (presumably back into the
    original image's coordinate frame -- confirm in PreProcessor).

    Returns:
        dict: ``{"keypoints": [...]}`` with one entry per predicted point, in
        whatever form ``reverse`` returns it.

    Raises:
        HTTPException: 400 if the body is not JSON, lacks an ``image`` field,
            or the image cannot be decoded.
    """
    try:
        data = await request.json()
        img = decode_base64_image(data["image"])
    except (KeyError, TypeError, AttributeError, ValueError, OSError) as err:
        # Malformed client input should produce a 400, not an unhandled 500.
        raise HTTPException(
            status_code=400,
            detail="Body must be JSON with a base64-encoded 'image' field.",
        ) from err
    processed, _, reverse = transform_data(img, [])
    results = predict_keypoints(processed, model)
    keypoints = [reverse(x, y) for x, y in results]
    return {"keypoints": keypoints}