Surajo commited on
Commit
7f8e126
·
verified ·
1 Parent(s): a06e8e2

Upload 3 files

Browse files
Files changed (3) hide show
  1. LR_model.pth +3 -0
  2. app.py +55 -0
  3. requirement.txt +4 -0
LR_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d152ffffdf42731a7ddb95c43975586b675db36966498bb4da0e70d5f7254652
3
+ size 288854283
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models, transforms
4
+ from PIL import Image
5
+ import gradio as gr
6
+
7
+ def load_model(path="LR_model.pth"):
8
+ model = models.resnet50(weights=None)
9
+
10
+ # Your saved model has a Sequential head, not just one linear layer
11
+ model.fc = nn.Sequential(
12
+ nn.Linear(model.fc.in_features, 256),
13
+ nn.ReLU(),
14
+ nn.Dropout(0.4),
15
+ nn.Linear(256, 2)
16
+ )
17
+
18
+ checkpoint = torch.load(path, map_location="cpu")
19
+ model.load_state_dict(checkpoint["model_state_dict"])
20
+ model.eval()
21
+ return model
22
+
23
+
24
+
25
+ # Image preprocessing
26
+ transform = transforms.Compose([
27
+ transforms.Resize((224, 224)),
28
+ transforms.ToTensor(),
29
+ transforms.Normalize([0.4326, 0.4953, 0.3120], [0.2178, 0.2214, 0.2091])
30
+ ])
31
+
32
+ # Predict function
33
+ def predict(img):
34
+ img = img.convert("RGB")
35
+ tensor = transform(img).unsqueeze(0)
36
+ with torch.no_grad():
37
+ output = model(tensor)
38
+ probs = torch.nn.functional.softmax(output, dim=1)
39
+ idx = probs.argmax().item()
40
+ conf = probs[0][idx].item()
41
+ return {"Parasitized" if idx == 0 else "Uninfected": conf}
42
+
43
+ # Load model once
44
+ model = load_model()
45
+
46
+ # Gradio UI
47
+ interface = gr.Interface(
48
+ fn=predict,
49
+ inputs=gr.Image(type="pil"),
50
+ outputs=gr.Label(num_top_classes=2),
51
+ title=" Malaria Cell Detection",
52
+ description="Upload a blood smear cell image to check for malaria (parasitized or uninfected)."
53
+ )
54
+
55
+ interface.launch()
requirement.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ Pillow