Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import onnxruntime as ort | |
| from torchvision.io import decode_image, ImageReadMode | |
| from torchvision.transforms.v2 import Compose, Resize, Normalize, ToDtype, Grayscale | |
| def sigmoid(x): | |
| return 1 / (1 + np.exp(-x)) | |
| def model(image, path_model): | |
| transform = Compose([ | |
| Resize([224, 224]), | |
| Grayscale(3), | |
| ToDtype(torch.float32, scale=True), | |
| Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) | |
| ]) | |
| ort_session = ort.InferenceSession( | |
| path_model, | |
| providers=["CPUExecutionProvider"] | |
| ) | |
| input_name = ort_session.get_inputs()[0].name | |
| input_data = transform(image).unsqueeze(0).detach().cpu().numpy() | |
| onnxruntime_input = {input_name: input_data} | |
| # Perform inference | |
| onnxruntime_outputs = ort_session.run(None, onnxruntime_input)[0] | |
| return onnxruntime_outputs[0] | |
| def predict(image): | |
| classes = ['норма', 'регуляторные ошибки', 'пространственные ошибки'] | |
| predict_class_proba = model(image, 'clock_classifier_model.onnx') | |
| dict_predict = {classes[i]: sigmoid(predict_class_proba[i]) for i in range(len(classes))} | |
| predict_regress = np.round(model(image, 'clock_regression_model.onnx')) | |
| return dict_predict, predict_regress[0] | |
| def clock(input_image): | |
| image = decode_image(input_image, mode=ImageReadMode.RGB) | |
| return predict(image) | |
| description = """ | |
| <p>Этот искусственный интеллект был специально обучен для анализа <b>теста рисования часов</b> - одного из самых распространенных когнитивных тестов в неврологии.</p> | |
| <p>Модель оценивает рисунки по двум ключевым аспектам:</p> | |
| <ul> | |
| <li><b>Качественная оценка</b> - выявление конкретных типов ошибок: | |
| <ul> | |
| <li>Регуляторные (неправильное расположение стрелок)</li> | |
| <li>Пространственные (нарушение геометрии циферблата)</li> | |
| </ul></li> | |
| <li><b>Количественная оценка</b> - общий балл от 1 до 5, где: | |
| <ul> | |
| <li>5 - идеальное выполнение</li> | |
| <li>4 - легкие отклонения</li> | |
| <li>3-2 - умеренные нарушения</li> | |
| <li>1 - грубые отклонения</li> | |
| </ul></li> | |
| </ul> | |
| <p style="font-style:italic;background:#e8f4f8;padding:10px;border-left:4px solid #3498db;"> | |
| <b>Инструкция, по которой обучался ИИ:</b><br> | |
| "Нарисуйте круг, расставьте все цифры как на часах, укажите время 'десять минут пятого'" | |
| </p> | |
| """ | |
| demo = gr.Interface( | |
| fn=clock, | |
| inputs=gr.Image(label='Загрузите рисунок часов', type="filepath"), | |
| outputs=[ | |
| gr.Label(label="Качественная оценка"), | |
| gr.Number(label="Количественная оценка (1-5)") | |
| ], | |
| title="Анализ теста рисования часов", | |
| description=description, | |
| allow_flagging="never" | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch() |