Spaces:
Runtime error
Runtime error
| import os | |
| from fastapi import FastAPI, WebSocket | |
| from YOLOv6.yolov6.core.inferer import Inferer | |
| import cv2 | |
| import yaml as YAML | |
| import json | |
| import csv | |
| import ssl | |
| import hashlib | |
| from entity import read_entities | |
| import imtool | |
# FastAPI application object; uvicorn serves it when this file is run directly.
app = FastAPI()
# --- Inference configuration ---------------------------------------------
# Consumed by the YOLOv6 Inferer constructor at startup and by
# Inferer.infer() for every received frame.

# Model checkpoint and dataset description file. ``yaml`` holds a *path*;
# the PyYAML module itself is imported under the alias ``YAML``.
weights = './runs/train/exp27/weights/best_stop_aug_ckpt.pt'
yaml = './data.yaml'

# Execution settings: CPU device, half precision switched off.
device, half = 'cpu', False
img_size = [640] * 2  # kept as a list; passed to Inferer as-is

# Post-processing knobs forwarded unchanged to Inferer.infer().
conf_thres, iou_thres = 0.5, 0.45
classes = None       # NOTE(review): presumably "no class filter" -- confirm in Inferer
agnostic_nms = None  # falsy; exact meaning is defined by Inferer.infer
max_det = 1_000
def _load_classes(path):
    """Read the dataset YAML file at *path* and return the parsed document.

    ``YAML.safe_load`` is used, so only plain YAML types are constructed.
    An empty (falsy) document yields ``{}`` instead of ``None``.
    """
    with open(path, 'r') as fh:
        return YAML.safe_load(fh.read()) or {}


def _load_certs(cert_dir, entities):
    """Decode every ``<stem>.cert`` file in *cert_dir*, keyed by entity name.

    Args:
        cert_dir: directory scanned (non-recursively) with ``os.scandir``.
        entities: lookup from a cert file's stem (file name without the
            ``.cert`` suffix) to an object with a ``.name`` attribute.
            NOTE(review): assumed to be a mapping as returned by
            ``read_entities`` -- confirm in the ``entity`` module.

    Returns:
        dict mapping entity name -> certificate dict produced by
        ``ssl._ssl._test_decode_cert``, extended with a ``'fingerprint'``
        entry holding the SHA-1 hex digest of the DER-encoded certificate.

    Files without a ``.cert`` extension are ignored. Files that fail to
    decode, or whose stem has no entry in *entities*, are reported and
    skipped, so one bad file no longer aborts the whole load.
    """
    loaded = {}
    with os.scandir(cert_dir) as it:
        for entry in it:
            # splitext copes with names that have no dot or several dots;
            # the old ``name.split('.')`` unpacking raised ValueError there.
            stem, ext = os.path.splitext(entry.name)
            if ext != '.cert':
                continue
            try:
                # NOTE(review): _test_decode_cert is a private CPython helper
                # with no public stdlib equivalent; it may change between
                # Python versions, hence the broad best-effort catch.
                cert_dict = ssl._ssl._test_decode_cert(entry.path)
                with open(entry.path, 'r') as fh:
                    der = ssl.PEM_cert_to_DER_cert(fh.read())
            except Exception as e:
                print("Error decoding certificate {}: {}".format(entry.name, e))
                continue
            cert_dict['fingerprint'] = hashlib.sha1(der).hexdigest()
            try:
                name = entities[stem].name
            except KeyError:
                print(f'No entity for certificate {entry.name}, skipping')
                continue
            loaded[name] = cert_dict
    return loaded


# Start from empty / None values so the route handlers always find these
# names defined, even when a loading step below fails.
classes_data = {}
entities = {}
certs = {}
inferer = None

# Each resource is loaded in its own try: a malformed cert or entity file can
# no longer prevent the model from loading. Failures are printed, not raised,
# keeping startup best-effort as before.
try:
    classes_data = _load_classes(yaml)
except Exception as e:
    print('error loading classes:', e)

try:
    entities = read_entities('../data/entities.csv')
    certs = _load_certs('../data/certs', entities)
except Exception as e:
    print('error loading certs:', e)

_class_names = classes_data.get('names') if isinstance(classes_data, dict) else None
print(f'loaded {len(certs)} certs, got {len(_class_names or [])} classes')

try:
    inferer = Inferer(weights, device, yaml, img_size, half)
except Exception as e:
    print('error loading model:', e)
@app.get("/")
async def root():
    """Health-check endpoint reporting that the API process is up.

    Without a route decorator this handler was never registered on ``app``;
    it is now served at ``/``.

    Returns:
        ``{"message": "API is working"}``, serialised to JSON by FastAPI.
    """
    return {"message": "API is working"}
async def websockets_cb(websocket: WebSocket):
    """Run detection on each base64-encoded image received over *websocket*.

    After accepting the connection the handler loops forever. Each text
    message is decoded with ``imtool.read_base64``, loaded into the shared
    ``inferer`` and run through ``inferer.infer`` with the module-level
    thresholds. The reply is the inference result string, then ``'@@@@'``,
    then the image shape as a bracketed comma-separated list, for example
    ``'...@@@@[480,640,3]'``.

    Any exception, including the client disconnecting, ends the loop. It is
    printed and not re-raised.

    NOTE(review): no route registration is visible for this handler; it needs
    ``@app.websocket(<path>)`` (path unknown from here) to be reachable.
    """
    try:
        await websocket.accept()
        while True:
            data = await websocket.receive_text()
            img = imtool.read_base64(data)
            # Debug artefacts: each decoded frame is written to debug.png and
            # debug.txt is deleted before every inference.
            # NOTE(review): presumably so the inferer starts a fresh
            # debug.txt per frame -- confirm.
            cv2.imwrite("debug.png", img)
            try:
                os.remove("debug.txt")
            except OSError:
                # A missing or unremovable file is fine. Unlike the previous
                # bare ``except:``, this does not swallow KeyboardInterrupt
                # or SystemExit.
                pass
            inferer.load(img)
            ret = inferer.infer(conf_thres, iou_thres, classes, agnostic_nms, max_det)
            print(ret)
            # Join every dimension: a fixed '[%d,%d,%d]' format raised
            # TypeError for 2-D (grayscale) or 4-channel images and dropped
            # the connection. Output is unchanged for 3-D shapes.
            shape = '[' + ','.join(str(dim) for dim in img.shape) + ']'
            await websocket.send_text(ret + '@@@@' + shape)
    except Exception as e:
        print("got: ", e)
async def send_classes(websocket: WebSocket):
    """Send the dataset classes and loaded certificates once, then close.

    The single text message is a JSON object with two keys: ``'classes'``,
    holding the module-level ``classes_data``, and ``'certs'``, holding the
    module-level ``certs`` dict.

    NOTE(review): no route registration is visible for this handler; it needs
    ``@app.websocket(<path>)`` to be reachable.
    """
    await websocket.accept()
    message = json.dumps(dict(classes=classes_data, certs=certs))
    await websocket.send_text(message)
    await websocket.close()
if __name__ == "__main__":
    import uvicorn

    # Pass the already-built ``app`` object rather than the import string
    # "api:app". With the string, uvicorn re-imports this file as module
    # ``api`` (it is currently running as ``__main__``). That re-runs every
    # module-level startup step, loading the model and certificates a second
    # time, and fails outright if the file is not named api.py. An import
    # string is only required for reload/workers, which are not used here.
    config = uvicorn.Config(app, port=5000, log_level="info")
    server = uvicorn.Server(config)
    server.run()