# plant-model-api / app.py
# Last commit: 126a28e "fix plant-disease-model" (author: Sbzc)
import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
# ResNet9 model definition
def ConvBlock(in_channels, out_channels, pool=False):
    """Return a conv stage: 3x3 conv -> batch norm -> ReLU [-> 4x4 max-pool].

    The 3x3 convolution uses padding 1, so height and width are unchanged
    unless ``pool`` is set, in which case a 4x4 max-pool (stride 4) shrinks
    both by a factor of 4.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        pool: when true, append a 4x4 max-pooling layer.

    Returns:
        nn.Sequential whose layers sit at indices 0 (Conv2d), 1 (BatchNorm2d),
        2 (ReLU) and, when pooling, 3 (MaxPool2d).
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
    norm = nn.BatchNorm2d(out_channels)
    act = nn.ReLU(inplace=True)
    # Keep this order: the Sequential indices become state_dict keys
    # ("0.weight", "1.running_mean", ...) that saved checkpoints rely on.
    if not pool:
        return nn.Sequential(conv, norm, act)
    return nn.Sequential(conv, norm, act, nn.MaxPool2d(4))
class ResNet9(nn.Module):
    """ResNet9 image classifier used for plant-disease prediction.

    Stages (input -> output channels):
        conv1       in_channels -> 64
        conv2       64 -> 128, then 4x4 max-pool
        res1        two 128 -> 128 blocks, added onto their input (skip connection)
        conv3       128 -> 256, then 4x4 max-pool
        conv4       256 -> 512, then 4x4 max-pool
        res2        two 512 -> 512 blocks, added onto their input
        classifier  4x4 max-pool, flatten, linear 512 -> num_diseases

    The final Linear layer takes exactly 512 features, so the 512-channel
    feature map must be pooled down to 1x1 before flattening; e.g. a 256x256
    input shrinks 256 -> 64 -> 16 -> 4 -> 1 through the four 4x4 pools.

    Args:
        in_channels: number of input image channels.
        num_diseases: number of output classes (logits).
    """

    def __init__(self, in_channels, num_diseases):
        super().__init__()
        # NOTE: the attribute names below and the submodule order inside each
        # Sequential become the checkpoint's state_dict keys ("conv1.0.weight",
        # "classifier.2.bias", ...). Renaming or reordering breaks loading.
        self.conv1 = ConvBlock(in_channels, 64)
        self.conv2 = ConvBlock(64, 128, pool=True)
        self.res1 = nn.Sequential(
            ConvBlock(128, 128),
            ConvBlock(128, 128),
        )
        self.conv3 = ConvBlock(128, 256, pool=True)
        self.conv4 = ConvBlock(256, 512, pool=True)
        self.res2 = nn.Sequential(
            ConvBlock(512, 512),
            ConvBlock(512, 512),
        )
        self.classifier = nn.Sequential(
            nn.MaxPool2d(4),
            nn.Flatten(),
            nn.Linear(512, num_diseases),
        )

    def forward(self, xb):
        """Return raw class logits (no softmax) for the image batch ``xb``."""
        # Stem: two conv stages, the second one downsampling.
        features = self.conv2(self.conv1(xb))
        # First residual stage.
        features = self.res1(features) + features
        # Widen to 512 channels, downsampling twice.
        features = self.conv4(self.conv3(features))
        # Second residual stage.
        features = self.res2(features) + features
        return self.classifier(features)
# Disease class names — CHANGE THESE TO MATCH YOUR OWN MODEL.
# Each name's position is the model output index it labels:
# len(CLASS_NAMES) sets the size of the classifier's output layer, and
# predictions are looked up as CLASS_NAMES[index], so adding, removing or
# reordering entries changes which label each output receives.
# NOTE(review): the list appears ASCII-sorted (the order torchvision's
# ImageFolder gives class folders) — confirm it matches the class order
# used when the checkpoint was trained.
# Names follow "<Crop>___<condition>"; the display code turns "___" into
# " - " and the remaining underscores into spaces.
CLASS_NAMES = [
    'Apple___Apple_scab',
    'Apple___Black_rot',
    'Apple___Cedar_apple_rust',
    'Apple___healthy',
    'Blueberry___healthy',
    'Cherry_(including_sour)___Powdery_mildew',
    'Cherry_(including_sour)___healthy',
    'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
    'Corn_(maize)___Common_rust_',
    'Corn_(maize)___Northern_Leaf_Blight',
    'Corn_(maize)___healthy',
    'Grape___Black_rot',
    'Grape___Esca_(Black_Measles)',
    'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
    'Grape___healthy',
    'Orange___Haunglongbing_(Citrus_greening)',
    'Peach___Bacterial_spot',
    'Peach___healthy',
    'Pepper,_bell___Bacterial_spot',
    'Pepper,_bell___healthy',
    'Potato___Early_blight',
    'Potato___Late_blight',
    'Potato___healthy',
    'Raspberry___healthy',
    'Soybean___healthy',
    'Squash___Powdery_mildew',
    'Strawberry___Leaf_scorch',
    'Strawberry___healthy',
    'Tomato___Bacterial_spot',
    'Tomato___Early_blight',
    'Tomato___Late_blight',
    'Tomato___Leaf_Mold',
    'Tomato___Septoria_leaf_spot',
    'Tomato___Spider_mites Two-spotted_spider_mite',
    'Tomato___Target_Spot',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
    'Tomato___Tomato_mosaic_virus',
    'Tomato___healthy'
]
# Pick the compute device: the current CUDA device if PyTorch sees a GPU, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the model: build the architecture with one output per class label,
# then restore the trained weights. The checkpoint path is relative, so it is
# resolved against the process's current working directory.
model = ResNet9(in_channels=3, num_diseases=len(CLASS_NAMES))
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can run code embedded in the file. Only ship a checkpoint you
# trust; if it is a plain state_dict, weights_only=True should work — confirm.
model.load_state_dict(
    torch.load(
        "plant-disease-model.pth",
        map_location=device,
        weights_only=False,
    )
)
# Move the weights to `device`; eval() makes BatchNorm use its running statistics.
model.to(device).eval()
# Transform: preprocessing applied to every uploaded image before inference.
#   Resize    -> fixed 256x256 (aspect ratio is not preserved)
#   ToTensor  -> float CHW tensor, scaled to [0, 1] for 8-bit images
#   Normalize -> per-channel (x - mean) / std using the standard ImageNet values
# NOTE(review): confirm the checkpoint was trained with this same Normalize
# step; mismatched train/inference preprocessing silently degrades accuracy.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Prediction function
def predict(image):
    """Predict the most likely plant diseases shown in an image.

    Args:
        image: PIL image from the Gradio ``gr.Image(type="pil")`` input, or
            ``None`` when the form is submitted without an upload.

    Returns:
        dict mapping a readable class name (e.g. ``"Apple - Apple scab"``)
        to its softmax probability, for the top 5 classes (or every class if
        there are fewer than 5). An empty dict when ``image`` is None.
    """
    if image is None:
        return {}

    # Uploads may be RGBA (PNG with alpha), grayscale or palette images; the
    # model is built with in_channels=3 and `transform` normalizes exactly 3
    # channels, so force RGB before preprocessing.
    img_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(img_tensor)
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)

    # torch.topk raises if k exceeds the number of classes, so clamp it.
    k = min(5, probabilities.numel())
    top_prob, top_idx = torch.topk(probabilities, k)

    results = {}
    for prob, idx in zip(top_prob.tolist(), top_idx.tolist()):
        # 'Apple___Apple_scab' -> 'Apple - Apple scab'; strip() drops the
        # trailing space left by names ending in '_' (e.g. 'Common_rust_').
        display_name = (
            CLASS_NAMES[idx].replace("___", " - ").replace("_", " ").strip()
        )
        results[display_name] = prob
    return results
# Gradio interface: one PIL image in, a top-5 label/confidence panel out.
demo = gr.Interface(
    fn=predict,
    # Label: "Upload a plant image"
    inputs=gr.Image(type="pil", label="🌿 Bitki Resmi Yükleyin"),
    # Label: "Prediction results"
    outputs=gr.Label(num_top_classes=5, label="🔍 Tahmin Sonuçları"),
    title="🌱 Plant Disease Detection",
    # "Upload a picture of plant leaves and get a disease diagnosis! The model
    # can detect 38 different plant diseases."
    # NOTE(review): 38 is hard-coded and counts the "healthy" classes too;
    # keep it in sync with len(CLASS_NAMES).
    description=(
        "Bitki yapraklarının resmini yükleyin, hastalık tespiti yapılsın! "
        "Model 38 farklı bitki hastalığını tespit edebilir."
    ),
    # Example images can be added here.
    examples=[],
    theme="soft",
    # Only overrides the app font.
    css="\n.gradio-container {\nfont-family: 'IBM Plex Sans', sans-serif;\n}\n",
)
def _main():
    """Start the Gradio web server for the interface in ``demo``."""
    demo.launch()


if __name__ == "__main__":
    _main()