"""Gradio demo that classifies diabetic foot ulcer (DFU) images.

Loads a trained ``DenseShuffleGCANet`` checkpoint once at import time and
exposes a single-image web interface that reports per-class probabilities
and the most likely class.
"""

import torch
import torch.nn.functional as F
import gradio as gr
import numpy as np
from PIL import Image
from torchvision import transforms

# Project model definition and handcrafted-feature extractor.
from model import DenseShuffleGCANet, extract_handcrafted_features

# Output labels, in the same order as the model's output logits.
CLASSES = ["NONE", "INFECTION", "ISCHAEMIA", "BOTH"]

# Run on GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network and load the trained weights.
# NOTE(review): 41 must equal the length of the vector returned by
# extract_handcrafted_features — confirm against model.py.
model = DenseShuffleGCANet(num_classes=len(CLASSES), handcrafted_feature_dim=41)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load("best_model_2.pth", map_location=device))
model.to(device)
model.eval()  # inference mode for layers such as dropout / batch-norm

# Preprocessing for the image branch: resize to 224x224, convert to a [0, 1]
# float tensor, then normalize per channel with the standard ImageNet
# mean/std (presumably matching the training pipeline — confirm).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


def predict(image: Image.Image):
    """Classify a single DFU image.

    Args:
        image: The uploaded image. Any PIL mode is accepted; it is converted
            to RGB because the normalization step expects exactly 3 channels.

    Returns:
        A ``(probabilities, predicted_class)`` tuple: a dict mapping every
        label in ``CLASSES`` to its softmax probability, and the label with
        the highest probability.

    Raises:
        gr.Error: If no image was provided (e.g. the form was submitted
            without an upload and Gradio passed ``None``).
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # RGBA / grayscale / palette images would otherwise yield a tensor with
    # the wrong channel count and make Normalize fail.
    image = image.convert("RGB")

    image_tensor = transform(image).unsqueeze(0).to(device)  # add batch dim
    features = extract_handcrafted_features(np.array(image))
    features = features.unsqueeze(0).to(device)  # add batch dim

    with torch.no_grad():
        logits = model(image_tensor, features)

    # Move to CPU once rather than syncing the device for every element.
    probs = F.softmax(logits, dim=1)[0].cpu()
    result = {label: float(p) for label, p in zip(CLASSES, probs)}
    predicted_class = CLASSES[int(probs.argmax())]
    return result, predicted_class


# Gradio web interface: one image in, probabilities + top label out.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload DFU Image"),
    outputs=[
        gr.Label(num_top_classes=4, label="Probabilities"),
        gr.Textbox(label="Predicted Class"),
    ],
    title="DFU Classification System",
    description="Classifies diabetic foot images into NONE, INFECTION, ISCHAEMIA, or BOTH.",
)

if __name__ == "__main__":
    interface.launch()