datanerdke commited on
Commit
3db71fb
·
verified ·
1 Parent(s): ad8ada9

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -63
app.py CHANGED
@@ -1,63 +1,55 @@
1
- import torch
2
- import torch.nn as nn
3
- from torchvision import models, transforms
4
- from PIL import Image
5
- import gradio as gr
6
-
7
- # Load model
8
- def load_model(path="LR_model.pth"):
9
- model = models.resnet50(weights=None)
10
- model.fc = nn.Sequential(
11
- nn.Linear(model.fc.in_features, 256),
12
- nn.ReLU(),
13
- nn.Dropout(0.4),
14
- nn.Linear(256, 2)
15
- )
16
- checkpoint = torch.load(path, map_location="cpu")
17
- model.load_state_dict(checkpoint["model_state_dict"])
18
- model.eval()
19
- return model
20
-
21
- model = load_model()
22
-
23
- # Preprocessing
24
- transform = transforms.Compose([
25
- transforms.Resize((224, 224)),
26
- transforms.ToTensor(),
27
- transforms.Normalize([0.4326, 0.4953, 0.3120], [0.2178, 0.2214, 0.2091])
28
- ])
29
-
30
- # Prediction logic
31
- def predict(img):
32
- img = img.convert("RGB")
33
- tensor = transform(img).unsqueeze(0)
34
- with torch.no_grad():
35
- output = model(tensor)
36
- probs = torch.nn.functional.softmax(output, dim=1)
37
- idx = probs.argmax().item()
38
- conf = probs[0][idx].item()
39
- label = "🦠 Parasitized" if idx == 0 else "✅ Uninfected"
40
- return {label: conf}
41
-
42
- # Custom theme and layout using Gradio Blocks (v3)
43
- with gr.Blocks(theme=gr.themes.Soft()) as interface:
44
- gr.Markdown("# 🧬 Malaria Cell Detection App")
45
- gr.Markdown(
46
- "Upload a blood smear cell image to predict whether it contains malaria parasites or not. "
47
- "The model is based on ResNet50 fine-tuned for binary classification (Parasitized vs Uninfected)."
48
- )
49
-
50
- with gr.Row():
51
- image_input = gr.Image(type="pil", label="Upload Blood Smear Image", tool="editor")
52
- output_label = gr.Label(label="Prediction Confidence")
53
-
54
- analyze_btn = gr.Button("🔍 Analyze")
55
- analyze_btn.click(fn=predict, inputs=image_input, outputs=output_label)
56
-
57
- gr.Markdown("---")
58
- gr.Markdown(
59
- "📌 **Model Info**: Fine-tuned ResNet50 | Input size: 224x224 | Confidence shown as probability.\n"
60
- "🧪 This is a demo model. For real medical use, consult healthcare professionals."
61
- )
62
-
63
- interface.launch()
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models, transforms
4
+ from PIL import Image
5
+ import gradio as gr
6
+
7
+ def load_model(path="LR_model.pth"):
8
+ model = models.resnet50(weights=None)
9
+
10
+ # Your saved model has a Sequential head, not just one linear layer
11
+ model.fc = nn.Sequential(
12
+ nn.Linear(model.fc.in_features, 256),
13
+ nn.ReLU(),
14
+ nn.Dropout(0.4),
15
+ nn.Linear(256, 2)
16
+ )
17
+
18
+ checkpoint = torch.load(path, map_location="cpu")
19
+ model.load_state_dict(checkpoint["model_state_dict"])
20
+ model.eval()
21
+ return model
22
+
23
+
24
+
25
+ # Image preprocessing
26
+ transform = transforms.Compose([
27
+ transforms.Resize((224, 224)),
28
+ transforms.ToTensor(),
29
+ transforms.Normalize([0.4326, 0.4953, 0.3120], [0.2178, 0.2214, 0.2091])
30
+ ])
31
+
32
+ # Predict function
33
+ def predict(img):
34
+ img = img.convert("RGB")
35
+ tensor = transform(img).unsqueeze(0)
36
+ with torch.no_grad():
37
+ output = model(tensor)
38
+ probs = torch.nn.functional.softmax(output, dim=1)
39
+ idx = probs.argmax().item()
40
+ conf = probs[0][idx].item()
41
+ return {"Parasitized" if idx == 0 else "Uninfected": conf}
42
+
43
+ # Load model once
44
+ model = load_model()
45
+
46
+ # Gradio UI
47
+ interface = gr.Interface(
48
+ fn=predict,
49
+ inputs=gr.Image(type="pil"),
50
+ outputs=gr.Label(num_top_classes=2),
51
+ title=" Malaria Cell Detection",
52
+ description="Upload a blood smear cell image to check for malaria (parasitized or uninfected)."
53
+ )
54
+
55
+ interface.launch()