Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| import torchvision.transforms as transforms | |
# Download the model checkpoint at startup.
# A network/Hub failure must not crash the whole app: on error model_path is
# set to None, so the model-loading step fails inside its own try/except and
# the app starts with model = None ("Model not loaded") instead.
print("Loading model...")
try:
    model_path = hf_hub_download(
        repo_id="cwlachap/hrnet-cephalometric-landmark-detection",
        filename="best_model.pth",
    )
except Exception as e:
    print(f"Error downloading model: {e}")
    model_path = None
# A real HRNet architecture is needed here; this is a simplified stand-in
# for the demo.
# NOTE(review): this network is NOT HRNet, so its parameter names/shapes are
# unlikely to match an HRNet checkpoint — confirm before trusting predictions.
class SimpleHRNet(torch.nn.Module):
    """Tiny landmark regressor: conv -> ReLU -> global average pool -> linear.

    Maps an image batch of shape (batch, in_channels, H, W) to landmark
    coordinates of shape (batch, num_landmarks, 2); the last axis holds
    (x, y).

    Args:
        num_landmarks: Number of landmarks to predict (default 19).
        in_channels: Channels of the input image (default 3, RGB).
        hidden_channels: Output channels of the single convolution (default 64).
    """

    def __init__(self, num_landmarks=19, in_channels=3, hidden_channels=64):
        super().__init__()
        self.num_landmarks = num_landmarks
        # The attribute names (`features`, `landmarks`) and the Sequential
        # layout determine the state_dict keys; keep them stable so existing
        # checkpoints keep loading.
        self.features = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, hidden_channels, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.AdaptiveAvgPool2d((1, 1)),
        )
        # Two outputs (x and y) per landmark.
        self.landmarks = torch.nn.Linear(hidden_channels, num_landmarks * 2)

    def forward(self, x):
        """Return landmark coordinates of shape (batch, num_landmarks, 2)."""
        x = self.features(x)       # (batch, hidden_channels, 1, 1)
        x = torch.flatten(x, 1)    # (batch, hidden_channels)
        x = self.landmarks(x)      # (batch, num_landmarks * 2)
        return x.view(-1, self.num_landmarks, 2)
# Build the model and load the downloaded checkpoint into it. Any failure
# leaves `model` as None, which analyze_image() reports as "Model not loaded"
# instead of crashing the app or serving an untrained network.
try:
    if model_path is None:
        raise FileNotFoundError("model checkpoint is unavailable")
    model = SimpleHRNet()
    # NOTE(security): torch.load unpickles the file, and this checkpoint comes
    # from a third-party Hub repo. Consider weights_only=True (the default
    # since PyTorch 2.6) so loading cannot execute arbitrary code.
    checkpoint = torch.load(model_path, map_location='cpu')
    # Training checkpoints usually nest the weights under a key; otherwise
    # the loaded object itself is treated as the state dict.
    state_dict = checkpoint
    if isinstance(checkpoint, dict):
        for key in ('model_state_dict', 'state_dict'):
            if key in checkpoint:
                state_dict = checkpoint[key]
                break
    # strict=False tolerates missing/unexpected keys (shape mismatches still
    # raise), so verify that the checkpoint actually initialised something.
    incompatible = model.load_state_dict(state_dict, strict=False)
    missing = incompatible.missing_keys
    unexpected = incompatible.unexpected_keys
    if len(missing) == len(model.state_dict()):
        raise RuntimeError("no checkpoint weights match the model architecture")
    if missing or unexpected:
        print(
            f"Warning: {len(missing)} missing and {len(unexpected)} "
            f"unexpected keys while loading checkpoint"
        )
    model.eval()
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    model = None
# Human-readable landmark names. analyze_image() labels row i of the model
# output with entry i, so this order must match the order the model was
# trained on.
# NOTE(review): the order looks like the ISBI 2015 cephalometric challenge
# convention — confirm against the checkpoint's training data.
LANDMARK_NAMES = [
    "Sella",
    "Nasion",
    "Orbitale",
    "Porion",
    "Point A",
    "Point B",
    "Pogonion",
    "Menton",
    "Gnathion",
    "Gonion",
    "Lower Incisor Tip",
    "Upper Incisor Tip",
    "Upper Lip",
    "Lower Lip",
    "Subnasale",
    "Soft Tissue Pogonion",
    "Posterior Nasal Spine",
    "Anterior Nasal Spine",
    "Articulare",
]
# Side length (in pixels) of the square image fed to the model.
_INPUT_SIZE = 768


def analyze_image(image):
    """Detect cephalometric landmarks in an uploaded image.

    Args:
        image: Image array whose first two axes are (height, width), e.g. from
            a Gradio ``gr.Image`` input, or None if nothing was uploaded.

    Returns:
        ``(landmarks, image)`` on success, where ``landmarks`` is a list of
        ``{"landmark": name, "x": float, "y": float}`` dicts in the original
        image's pixel coordinates and ``image`` is the input echoed back.
        ``({"error": message}, None)`` on any failure.
    """
    if model is None:
        return {"error": "Model not loaded"}, None
    if image is None:
        return {"error": "No image provided"}, None
    try:
        # Resize to the model's input size, then normalise with the widely
        # used ImageNet mean/std (presumably what the checkpoint expects —
        # confirm against its training code).
        transform = transforms.Compose([
            transforms.Resize((_INPUT_SIZE, _INPUT_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        img = Image.fromarray(image).convert('RGB')
        img_tensor = transform(img).unsqueeze(0)  # add batch dimension

        with torch.no_grad():
            # Take the single batch entry explicitly; squeeze() would also
            # collapse the landmark axis for a one-landmark model.
            coords = model(img_tensor)[0].cpu().numpy()  # (num_landmarks, 2)

        if len(coords) != len(LANDMARK_NAMES):
            raise ValueError(
                f"model returned {len(coords)} landmarks, "
                f"expected {len(LANDMARK_NAMES)}"
            )

        # Map from the resized model frame back to the original image size.
        # Assumes the model predicts pixel coordinates in the
        # _INPUT_SIZE x _INPUT_SIZE frame — TODO confirm for the real checkpoint.
        h, w = image.shape[:2]
        coords[:, 0] = coords[:, 0] * w / _INPUT_SIZE
        coords[:, 1] = coords[:, 1] * h / _INPUT_SIZE

        result = [
            {"landmark": name, "x": float(x), "y": float(y)}
            for name, (x, y) in zip(LANDMARK_NAMES, coords)
        ]
        return result, image
    except Exception as e:
        return {"error": str(e)}, None
def process_for_api(image):
    """Run analyze_image() and return only its JSON-friendly payload.

    The echoed image is discarded; the result is either the list of landmark
    dicts or an ``{"error": ...}`` dict.
    """
    payload, _echoed_image = analyze_image(image)
    return payload
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Cephalometric Landmark Detection")
    gr.Markdown("Upload a lateral cephalometric radiograph to detect 19 anatomical landmarks")
    with gr.Row():
        # analyze_image() reads image.shape, so request a NumPy array.
        input_image = gr.Image(label="Upload X-ray", type="numpy")
        output_json = gr.JSON(label="Detected Landmarks")
    analyze_btn = gr.Button("Analyze")
    # api_name exposes this event as the /predict endpoint of the launched
    # app, so UI and API clients share one code path.
    analyze_btn.click(
        fn=process_for_api,
        inputs=input_image,
        outputs=output_json,
        api_name="predict",
    )

# Alias kept so existing references to `api_demo` still resolve; the /predict
# endpoint is served by `demo` itself (a separate, never-launched app would
# never serve it).
api_demo = demo

if __name__ == "__main__":
    demo.launch()