import gradio as gr
import torch
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
import torchvision.transforms as transforms

# Download the checkpoint once at startup so the first request is not delayed.
print("Loading model...")
model_path = hf_hub_download(
    repo_id="cwlachap/hrnet-cephalometric-landmark-detection",
    filename="best_model.pth"
)


class SimpleHRNet(torch.nn.Module):
    """Simplified stand-in for the HRNet architecture (demo only).

    NOTE(review): the published checkpoint was presumably trained on a real
    HRNet; loading it into this tiny net with ``strict=False`` will silently
    skip nearly every weight, so predictions may be effectively untrained —
    confirm the true architecture before trusting the output.
    """

    def __init__(self):
        super().__init__()
        # Minimal conv backbone followed by a linear regression head.
        self.features = torch.nn.Sequential(
            torch.nn.Conv2d(3, 64, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.AdaptiveAvgPool2d((1, 1))
        )
        self.landmarks = torch.nn.Linear(64, 19 * 2)  # 19 landmarks, x and y

    def forward(self, x):
        """Return landmark coordinates with shape (batch, 19, 2)."""
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.landmarks(x)
        return x.view(-1, 19, 2)


try:
    model = SimpleHRNet()
    # SECURITY NOTE(review): torch.load unpickles arbitrary objects from a
    # downloaded file; prefer weights_only=True if the checkpoint allows it.
    checkpoint = torch.load(model_path, map_location='cpu')
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    else:
        # Fix: a raw state-dict checkpoint was previously ignored entirely,
        # leaving the model with random weights.
        model.load_state_dict(checkpoint, strict=False)
    model.eval()
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    model = None

# Names of the 19 cephalometric landmarks, in model output order.
LANDMARK_NAMES = [
    "Sella", "Nasion", "Orbitale", "Porion", "Point A", "Point B",
    "Pogonion", "Menton", "Gnathion", "Gonion", "Lower Incisor Tip",
    "Upper Incisor Tip", "Upper Lip", "Lower Lip", "Subnasale",
    "Soft Tissue Pogonion", "Posterior Nasal Spine",
    "Anterior Nasal Spine", "Articulare"
]

# Preprocessing pipeline, hoisted to module level: it is identical for every
# request, so there is no reason to rebuild it per call.
_TRANSFORM = transforms.Compose([
    transforms.Resize((768, 768)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])


def analyze_image(image):
    """Detect the 19 landmarks on a lateral cephalometric radiograph.

    Args:
        image: HxW[xC] numpy array as delivered by ``gr.Image``, or None.

    Returns:
        Tuple ``(result, image)`` where ``result`` is a list of
        ``{"landmark", "x", "y"}`` dicts in original-image pixel coordinates,
        or an ``{"error": ...}`` dict on failure (second element None).
    """
    if model is None:
        return {"error": "Model not loaded"}, None
    if image is None:
        # Fix: clicking "Analyze" with no upload previously crashed inside
        # Image.fromarray(None) instead of reporting a clean error.
        return {"error": "No image provided"}, None
    try:
        img = Image.fromarray(image).convert('RGB')
        img_tensor = _TRANSFORM(img).unsqueeze(0)

        # Inference only — no gradients needed.
        with torch.no_grad():
            landmarks = model(img_tensor)

        # Rescale predictions to the original image size.
        # NOTE(review): assumes the model emits pixel coordinates in the
        # 768x768 input space — confirm against the training code.
        h, w = image.shape[:2]
        landmarks = landmarks.squeeze().cpu().numpy()
        landmarks[:, 0] = landmarks[:, 0] * w / 768
        landmarks[:, 1] = landmarks[:, 1] * h / 768

        result = [
            {"landmark": LANDMARK_NAMES[i], "x": float(x), "y": float(y)}
            for i, (x, y) in enumerate(landmarks)
        ]
        return result, image
    except Exception as e:
        return {"error": str(e)}, None


def process_for_api(image):
    """JSON-only wrapper around analyze_image for API callers."""
    result, _ = analyze_image(image)
    return result


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Cephalometric Landmark Detection")
    gr.Markdown("Upload a lateral cephalometric radiograph to detect 19 anatomical landmarks")
    with gr.Row():
        input_image = gr.Image(label="Upload X-ray")
        output_json = gr.JSON(label="Detected Landmarks")
    analyze_btn = gr.Button("Analyze")
    analyze_btn.click(fn=analyze_image,
                      inputs=input_image,
                      outputs=[output_json, gr.Image(visible=False)])

# API endpoint
# NOTE(review): this Interface is constructed but never launched or mounted —
# only `demo` is launched below, so `api_demo` appears to be dead code; verify
# whether it should be mounted (e.g. via gr.TabbedInterface) or removed.
api_demo = gr.Interface(
    fn=process_for_api,
    inputs=gr.Image(),
    outputs=gr.JSON(),
    api_name="predict"
)

if __name__ == "__main__":
    demo.launch()