File size: 3,329 Bytes
279ac73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import gradio as gr
import torch
import numpy as np
import onnxruntime as ort
from torchvision.io import decode_image, ImageReadMode
from  torchvision.transforms.v2 import Compose, Resize, Normalize, ToDtype, Grayscale


def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def model(image, path_model):
    transform = Compose([
        Resize([224, 224]),
        Grayscale(3),
        ToDtype(torch.float32, scale=True),
        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    ort_session = ort.InferenceSession(
        path_model,
        providers=["CPUExecutionProvider"]
    )

    input_name = ort_session.get_inputs()[0].name
    input_data = transform(image).unsqueeze(0).detach().cpu().numpy()
    onnxruntime_input = {input_name: input_data}

    # Perform inference
    onnxruntime_outputs = ort_session.run(None, onnxruntime_input)[0]
    return onnxruntime_outputs[0]

def predict(image):
    classes = ['норма', 'регуляторные ошибки', 'пространственные ошибки']
    predict_class_proba = model(image, 'clock_classifier_model.onnx')
    dict_predict = {classes[i]: sigmoid(predict_class_proba[i]) for i in range(len(classes))}
    
    predict_regress = np.round(model(image, 'clock_regression_model.onnx'))
    return  dict_predict, predict_regress[0]

def clock(input_image):
    image = decode_image(input_image, mode=ImageReadMode.RGB)
    return predict(image)



description = """
<p>Этот искусственный интеллект был специально обучен для анализа <b>теста рисования часов</b> - одного из самых распространенных когнитивных тестов в неврологии.</p>

<p>Модель оценивает рисунки по двум ключевым аспектам:</p>

<ul>
<li><b>Качественная оценка</b> - выявление конкретных типов ошибок:
<ul>
<li>Регуляторные (неправильное расположение стрелок)</li>
<li>Пространственные (нарушение геометрии циферблата)</li>
</ul></li>
<li><b>Количественная оценка</b> - общий балл от 1 до 5, где:
<ul>
<li>5 - идеальное выполнение</li>
<li>4 - легкие отклонения</li>
<li>3-2 - умеренные нарушения</li>
<li>1 - грубые отклонения</li>
</ul></li>
</ul>

<p style="font-style:italic;background:#e8f4f8;padding:10px;border-left:4px solid #3498db;">
<b>Инструкция, по которой обучался ИИ:</b><br>
"Нарисуйте круг, расставьте все цифры как на часах, укажите время 'десять минут пятого'"
</p>
"""

demo = gr.Interface(
    fn=clock,
    inputs=gr.Image(label='Загрузите рисунок часов', type="filepath"),
    outputs=[
        gr.Label(label="Качественная оценка"),
        gr.Number(label="Количественная оценка (1-5)")
    ],
    title="Анализ теста рисования часов",
    description=description,
    allow_flagging="never"
)

if __name__ == "__main__":
    demo.launch()