# CephAI / app.py — Hugging Face Space by HavajOrtho (commit 6140361)
import gradio as gr
import torch
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
import torchvision.transforms as transforms
# Download the pretrained checkpoint from the Hugging Face Hub at startup.
# hf_hub_download caches the file locally and returns its filesystem path.
print("Loading model...")
model_path = hf_hub_download(
    repo_id="cwlachap/hrnet-cephalometric-landmark-detection",
    filename="best_model.pth"
)
# The real model needs the full HRNet architecture — this is a simplified
# stand-in for the demo.
class SimpleHRNet(torch.nn.Module):
    """Lightweight stand-in for HRNet: regress 19 (x, y) landmark pairs.

    A single conv layer feeds global average pooling, and a linear head
    maps the pooled 64-dim feature vector to 19 coordinate pairs.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: conv -> ReLU -> global average pool to (1, 1).
        self.features = torch.nn.Sequential(
            torch.nn.Conv2d(3, 64, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.AdaptiveAvgPool2d((1, 1)),
        )
        # Regression head: one (x, y) pair per landmark.
        self.landmarks = torch.nn.Linear(64, 19 * 2)

    def forward(self, x):
        """Map a (B, 3, H, W) batch to (B, 19, 2) landmark coordinates."""
        pooled = self.features(x)
        flat = torch.flatten(pooled, start_dim=1)
        coords = self.landmarks(flat)
        return coords.reshape(-1, 19, 2)
# Instantiate the demo model and load the downloaded checkpoint weights.
# On any failure, `model` is set to None and analyze_image reports an error
# instead of crashing the app.
try:
    model = SimpleHRNet()
    checkpoint = torch.load(model_path, map_location='cpu')
    if 'model_state_dict' in checkpoint:
        # strict=False tolerates key mismatches between the full HRNet
        # checkpoint and this simplified architecture.
        # NOTE(review): with strict=False, non-matching HRNet weights are
        # silently skipped, so the demo may effectively run with randomly
        # initialized weights — confirm against the checkpoint's keys.
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    model.eval()
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    model = None
# Human-readable names for the 19 landmarks, indexed in model output order.
# NOTE(review): presumably the ISBI cephalometric challenge convention —
# confirm the ordering matches the checkpoint's training labels.
LANDMARK_NAMES = [
    "Sella", "Nasion", "Orbitale", "Porion", "Point A", "Point B",
    "Pogonion", "Menton", "Gnathion", "Gonion", "Lower Incisor Tip",
    "Upper Incisor Tip", "Upper Lip", "Lower Lip", "Subnasale",
    "Soft Tissue Pogonion", "Posterior Nasal Spine", "Anterior Nasal Spine",
    "Articulare"
]
def analyze_image(image):
    """Detect 19 cephalometric landmarks on an uploaded X-ray.

    Args:
        image: H x W (x3) numpy array from the Gradio image component,
            or None when no image was uploaded.

    Returns:
        Tuple of (result, image): on success, ``result`` is a list of
        ``{"landmark", "x", "y"}`` dicts and ``image`` echoes the input;
        on failure, ``result`` is ``{"error": ...}`` and ``image`` is None.
    """
    if model is None:
        return {"error": "Model not loaded"}, None
    # Gradio passes None when the user clicks Analyze without an upload;
    # reject it explicitly instead of relying on the broad except below.
    if image is None:
        return {"error": "No image provided"}, None
    try:
        # Preprocess: resize to the model's input size and apply the
        # standard ImageNet normalization statistics.
        transform = transforms.Compose([
            transforms.Resize((768, 768)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        img = Image.fromarray(image).convert('RGB')
        img_tensor = transform(img).unsqueeze(0)
        # Inference without gradient tracking.
        with torch.no_grad():
            landmarks = model(img_tensor)
        # Rescale predicted coordinates from the 768x768 model input space
        # back to the original image dimensions.
        h, w = image.shape[:2]
        landmarks = landmarks.squeeze(0).cpu().numpy()
        landmarks[:, 0] = landmarks[:, 0] * w / 768
        landmarks[:, 1] = landmarks[:, 1] * h / 768
        # Pair each coordinate with its anatomical name.
        result = [
            {"landmark": name, "x": float(x), "y": float(y)}
            for name, (x, y) in zip(LANDMARK_NAMES, landmarks)
        ]
        return result, image
    except Exception as e:
        # Surface the failure to the UI/API rather than crashing the app.
        return {"error": str(e)}, None
def process_for_api(image):
    """API entry point: return only the landmark JSON, dropping the echoed image."""
    landmarks, _unused_image = analyze_image(image)
    return landmarks
# Gradio UI: upload an X-ray, click Analyze, view the landmark JSON.
with gr.Blocks() as demo:
    gr.Markdown("# Cephalometric Landmark Detection")
    gr.Markdown("Upload a lateral cephalometric radiograph to detect 19 anatomical landmarks")
    with gr.Row():
        input_image = gr.Image(label="Upload X-ray")
        output_json = gr.JSON(label="Detected Landmarks")
    analyze_btn = gr.Button("Analyze")
    # analyze_image returns (json, image); the hidden Image component here
    # absorbs the echoed image so only the JSON is displayed.
    analyze_btn.click(fn=analyze_image, inputs=input_image, outputs=[output_json, gr.Image(visible=False)])
# API endpoint
# NOTE(review): this Interface is constructed but never launched — only
# `demo` is launched below — so it likely has no effect at runtime; also
# confirm gr.Interface accepts an `api_name` keyword in the pinned Gradio
# version (api_name is normally an event-listener argument).
api_demo = gr.Interface(
    fn=process_for_api,
    inputs=gr.Image(),
    outputs=gr.JSON(),
    api_name="predict"
)
# Launch the Blocks UI when run as a script (Spaces invokes this directly).
if __name__ == "__main__":
    demo.launch()