"""Gradio demo that classifies plant-leaf images with a ResNet9 model.

The model weights are loaded from ``plant-disease-model.pth`` at import time
and served through a ``gr.Interface`` that shows the top predictions.
"""

from typing import Dict, Optional

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms


# ResNet9 model definition
def ConvBlock(in_channels, out_channels, pool=False):
    """Build a Conv2d -> BatchNorm2d -> ReLU stack, optionally followed by MaxPool2d(4).

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        pool: If True, append a 4x4 max-pool, which divides each spatial
            dimension by 4.

    Returns:
        An ``nn.Sequential`` holding the layers in order (the submodule indices
        determine the state_dict keys, so the order must not change).
    """
    layers = [
        # A 3x3 kernel with padding=1 keeps the spatial size unchanged.
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    if pool:
        layers.append(nn.MaxPool2d(4))
    return nn.Sequential(*layers)


class ResNet9(nn.Module):
    """Nine-layer residual CNN that maps an image batch to per-class logits.

    Args:
        in_channels: Number of channels in the input images (3 for RGB).
        num_diseases: Number of output classes.
    """

    def __init__(self, in_channels, num_diseases):
        super().__init__()
        self.conv1 = ConvBlock(in_channels, 64)
        self.conv2 = ConvBlock(64, 128, pool=True)
        self.res1 = nn.Sequential(ConvBlock(128, 128), ConvBlock(128, 128))
        self.conv3 = ConvBlock(128, 256, pool=True)
        self.conv4 = ConvBlock(256, 512, pool=True)
        self.res2 = nn.Sequential(ConvBlock(512, 512), ConvBlock(512, 512))
        self.classifier = nn.Sequential(
            nn.MaxPool2d(4),
            nn.Flatten(),
            # 512 inputs assumes the feature map is 1x1 at this point.
            nn.Linear(512, num_diseases),
        )

    def forward(self, xb):
        """Return raw (unnormalized) class logits for the input batch ``xb``."""
        out = self.conv1(xb)
        out = self.conv2(out)
        # Residual skip: res1 preserves shape, so its output can be added back.
        out = self.res1(out) + out
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.res2(out) + out
        out = self.classifier(out)
        return out


# Class labels, indexed by the model's output position. Adapt this list to your
# own dataset; it must follow the same class order used during training
# (presumably the dataset's folder order - confirm against the training code).
CLASS_NAMES = [
    'Apple___Apple_scab',
    'Apple___Black_rot',
    'Apple___Cedar_apple_rust',
    'Apple___healthy',
    'Blueberry___healthy',
    'Cherry_(including_sour)___Powdery_mildew',
    'Cherry_(including_sour)___healthy',
    'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
    'Corn_(maize)___Common_rust_',
    'Corn_(maize)___Northern_Leaf_Blight',
    'Corn_(maize)___healthy',
    'Grape___Black_rot',
    'Grape___Esca_(Black_Measles)',
    'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
    'Grape___healthy',
    'Orange___Haunglongbing_(Citrus_greening)',
    'Peach___Bacterial_spot',
    'Peach___healthy',
    'Pepper,_bell___Bacterial_spot',
    'Pepper,_bell___healthy',
    'Potato___Early_blight',
    'Potato___Late_blight',
    'Potato___healthy',
    'Raspberry___healthy',
    'Soybean___healthy',
    'Squash___Powdery_mildew',
    'Strawberry___Leaf_scorch',
    'Strawberry___healthy',
    'Tomato___Bacterial_spot',
    'Tomato___Early_blight',
    'Tomato___Late_blight',
    'Tomato___Leaf_Mold',
    'Tomato___Septoria_leaf_spot',
    'Tomato___Spider_mites Two-spotted_spider_mite',
    'Tomato___Target_Spot',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
    'Tomato___Tomato_mosaic_virus',
    'Tomato___healthy',
]

# Number of top predictions returned by `predict` and shown in the UI.
TOP_K = 5

# Run on the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the network and load the trained weights.
model = ResNet9(in_channels=3, num_diseases=len(CLASS_NAMES))
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot run arbitrary code; load_state_dict only
# needs a mapping of tensors.
model.load_state_dict(
    torch.load('plant-disease-model.pth', map_location=device, weights_only=True)
)
model.to(device)
model.eval()

# Preprocessing pipeline applied to every uploaded image.
transform = transforms.Compose([
    # The classifier's Linear(512, ...) needs a 1x1 feature map, which a
    # 256x256 input yields after four MaxPool2d(4) stages (256 / 4**4 == 1).
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    # NOTE(review): these are ImageNet mean/std values; confirm they match the
    # preprocessing used when the checkpoint was trained, otherwise accuracy
    # degrades silently.
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def _display_name(class_name):
    """Turn a raw label like 'Apple___Black_rot' into 'Apple - Black rot'."""
    return class_name.replace('___', ' - ').replace('_', ' ')


def predict(image: Optional[Image.Image]) -> Dict[str, float]:
    """Predict the most likely plant diseases for a leaf image.

    Args:
        image: The uploaded PIL image, or None when the user submits without one.

    Returns:
        A mapping from human-readable class name to softmax probability for the
        ``TOP_K`` most likely classes, in descending order of probability. An
        empty dict is returned when no image was supplied.
    """
    if image is None:
        # Gradio passes None for an empty input; gr.Label treats {} as no result.
        return {}

    # The network expects 3 channels; RGBA or grayscale uploads (common for PNG)
    # would otherwise break Normalize and the first convolution.
    img_tensor = transform(image.convert('RGB')).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(img_tensor)
        probabilities = torch.nn.functional.softmax(logits[0], dim=0)

    # Clamp so topk never asks for more classes than exist.
    k = min(TOP_K, len(CLASS_NAMES))
    top_probs, top_indices = torch.topk(probabilities, k)
    # .tolist() yields plain Python floats/ints, which Gradio can serialize.
    return {
        _display_name(CLASS_NAMES[idx]): prob
        for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
    }


# Gradio interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="🌿 Bitki Resmi Yükleyin"),
    outputs=gr.Label(num_top_classes=TOP_K, label="🔍 Tahmin Sonuçları"),
    title="🌱 Plant Disease Detection",
    description=(
        "Bitki yapraklarının resmini yükleyin, hastalık tespiti yapılsın! "
        f"Model {len(CLASS_NAMES)} farklı bitki hastalığını tespit edebilir."
    ),
    examples=[
        # Add example image paths here, e.g. ["samples/leaf.jpg"].
    ],
    theme="soft",
    css="""
    .gradio-container { font-family: 'IBM Plex Sans', sans-serif; }
    """,
)

if __name__ == "__main__":
    demo.launch()