File size: 3,771 Bytes
6140361
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gradio as gr
import torch
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
import torchvision.transforms as transforms

# Download the pretrained checkpoint from the Hugging Face Hub at startup.
# This blocks until the file is present in the local HF cache; the returned
# value is the local filesystem path to the .pth file.
print("Loading model...")
model_path = hf_hub_download(
    repo_id="cwlachap/hrnet-cephalometric-landmark-detection",
    filename="best_model.pth"
)

# NOTE(review): the real HRNet architecture is needed here — this is a
# simplified stand-in used only for the demo.
class SimpleHRNet(torch.nn.Module):
    """Simplified placeholder network predicting 19 (x, y) landmark pairs.

    Takes an RGB image tensor of shape (B, 3, H, W) and returns a tensor
    of shape (B, 19, 2) with one (x, y) pair per landmark.
    """

    def __init__(self):
        super().__init__()
        # Tiny feature extractor: one conv layer collapsed to a single
        # 64-dim vector per image by global average pooling.
        conv = torch.nn.Conv2d(3, 64, 3, padding=1)
        act = torch.nn.ReLU()
        pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.features = torch.nn.Sequential(conv, act, pool)
        # One (x, y) pair for each of the 19 landmarks: 38 outputs.
        self.landmarks = torch.nn.Linear(64, 19 * 2)

    def forward(self, x):
        feats = self.features(x)
        flat = feats.reshape(feats.size(0), -1)
        coords = self.landmarks(flat)
        return coords.view(-1, 19, 2)

try:
    model = SimpleHRNet()
    # NOTE(review): torch.load unpickles arbitrary objects. The checkpoint
    # comes from a fixed HF repo here, but consider weights_only=True.
    checkpoint = torch.load(model_path, map_location='cpu')
    # Checkpoints may either wrap the weights under 'model_state_dict' or
    # be a raw state dict. Previously a raw state dict was silently ignored
    # and the model kept its random initialization.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    # strict=False: the simplified demo architecture does not match the real
    # HRNet checkpoint, so unmatched keys are skipped rather than erroring.
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    print("Model loaded successfully!")
except Exception as e:
    # Startup boundary: keep the app alive and signal failure via model=None,
    # which analyze_image checks before running inference.
    print(f"Error loading model: {e}")
    model = None

# Human-readable names for the 19 cephalometric landmarks, in model output
# order: prediction index i corresponds to LANDMARK_NAMES[i].
LANDMARK_NAMES = [
    "Sella", "Nasion", "Orbitale", "Porion", "Point A", "Point B",
    "Pogonion", "Menton", "Gnathion", "Gonion", "Lower Incisor Tip",
    "Upper Incisor Tip", "Upper Lip", "Lower Lip", "Subnasale",
    "Soft Tissue Pogonion", "Posterior Nasal Spine", "Anterior Nasal Spine",
    "Articulare"
]

def analyze_image(image):
    """Detect 19 cephalometric landmarks on a lateral radiograph.

    Args:
        image: numpy array from the Gradio image component (H x W x C),
            or None when no image was uploaded.

    Returns:
        Tuple of (result, image): result is a list of
        {"landmark", "x", "y"} dicts with pixel coordinates in the input
        image's own size, and image is echoed back unchanged. On failure
        returns ({"error": ...}, None) instead.
    """
    if model is None:
        return {"error": "Model not loaded"}, None
    # Fix: Gradio passes None when no image is uploaded; previously this
    # fell through to Image.fromarray(None) and produced an obscure error
    # message via the generic except below.
    if image is None:
        return {"error": "No image provided"}, None
    
    try:
        # Preprocess to the 768x768 input the model expects, using
        # ImageNet normalization statistics.
        transform = transforms.Compose([
            transforms.Resize((768, 768)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        img = Image.fromarray(image).convert('RGB')
        img_tensor = transform(img).unsqueeze(0)
        
        # Inference without gradient tracking.
        with torch.no_grad():
            landmarks = model(img_tensor)
        
        # Rescale predicted coordinates back to the original image size.
        # Assumes the model predicts in 768x768 input-pixel space —
        # TODO confirm against the training pipeline.
        h, w = image.shape[:2]
        landmarks = landmarks.squeeze().cpu().numpy()
        landmarks[:, 0] = landmarks[:, 0] * w / 768
        landmarks[:, 1] = landmarks[:, 1] * h / 768
        
        # Pair each coordinate with its anatomical name.
        result = []
        for i, (x, y) in enumerate(landmarks):
            result.append({
                "landmark": LANDMARK_NAMES[i],
                "x": float(x),
                "y": float(y)
            })
        
        return result, image
    
    except Exception as e:
        # UI boundary: surface the error as JSON instead of crashing the app.
        return {"error": str(e)}, None

def process_for_api(image):
    """API wrapper around analyze_image: return only the landmark JSON."""
    landmarks, _unused_image = analyze_image(image)
    return landmarks

# Gradio interface: image input on the left, detected-landmark JSON on the
# right, with an explicit "Analyze" button triggering inference.
with gr.Blocks() as demo:
    gr.Markdown("# Cephalometric Landmark Detection")
    gr.Markdown("Upload a lateral cephalometric radiograph to detect 19 anatomical landmarks")
    
    with gr.Row():
        input_image = gr.Image(label="Upload X-ray")
        output_json = gr.JSON(label="Detected Landmarks")
    
    analyze_btn = gr.Button("Analyze")
    # analyze_image returns (json, image); the echoed image is routed to a
    # hidden component created inline so only the JSON is shown.
    analyze_btn.click(fn=analyze_image, inputs=input_image, outputs=[output_json, gr.Image(visible=False)])

# API endpoint
# NOTE(review): api_demo is constructed but never launched or mounted —
# only `demo` below is served, so this Interface is currently dead code.
# Verify whether it should be combined with `demo` (e.g. via TabbedInterface)
# or removed.
api_demo = gr.Interface(
    fn=process_for_api,
    inputs=gr.Image(),
    outputs=gr.JSON(),
    api_name="predict"
)

if __name__ == "__main__":
    demo.launch()