ma4389 commited on
Commit
cf41001
·
verified ·
1 Parent(s): da82c81

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -50
app.py CHANGED
@@ -1,50 +1,60 @@
1
- import gradio as gr
2
- import torch
3
- import torch.nn as nn
4
- from torchvision import models, transforms
5
- from PIL import Image
6
-
7
- # Load the trained model
8
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
-
10
- # Define model
11
- model = models.resnet50(weights=None)
12
- in_features = model.fc.in_features
13
- model.fc = nn.Sequential(
14
- nn.Linear(in_features, 512),
15
- nn.ReLU(),
16
- nn.Dropout(0.4),
17
- nn.Linear(512, 47) # 47 classes
18
- )
19
- model.load_state_dict(torch.load("cancer_model.pth", map_location=device))
20
- model.to(device)
21
- model.eval()
22
-
23
- # Label mapping (update this with actual class names if available)
24
- class_names = [f"Class {i}" for i in range(47)]
25
-
26
- # Transforms (same as validation)
27
- transform = transforms.Compose([
28
- transforms.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
29
- transforms.Resize((224, 224)),
30
- transforms.ToTensor(),
31
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
32
- ])
33
-
34
- # Prediction function
35
- def predict(img):
36
- img = transform(img).unsqueeze(0).to(device)
37
- with torch.no_grad():
38
- outputs = model(img)
39
- probs = torch.softmax(outputs, dim=1)
40
- confidences, predicted = torch.max(probs, 1)
41
- return {class_names[i]: float(probs[0][i]) for i in range(len(class_names))}
42
-
43
- # Gradio UI
44
- gr.Interface(
45
- fn=predict,
46
- inputs=gr.Image(type="pil"),
47
- outputs=gr.Label(num_top_classes=5),
48
- title="Multi-Cancer Classifier",
49
- description="Upload a histopathology or cancer-related image. The model will predict its cancer type (47 classes)."
50
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import models, transforms
5
+ from PIL import Image
6
+
7
+ # Load the trained model
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
+ # Define model
11
+ model = models.resnet50(weights=None)
12
+ in_features = model.fc.in_features
13
+ model.fc = nn.Sequential(
14
+ nn.Linear(in_features, 512),
15
+ nn.ReLU(),
16
+ nn.Dropout(0.4),
17
+ nn.Linear(512, 47) # 47 classes
18
+ )
19
+ model.load_state_dict(torch.load("cancer_model.pth", map_location=device))
20
+ model.to(device)
21
+ model.eval()
22
+
23
+ # Actual class names (replace generic labels)
24
+ class_names = [
25
+ "lymph_fl", "colon_bnt", "cervix_koc", "colon_aca", "lymph_cll", "Skin_Melanoma", "Skin_Chickenpox",
26
+ "brain_glioma_tumor", "kidney_tumor", "all_pro", "Skin_Dyshidrotic Eczema", "Skin_Ringworm",
27
+ "cervix_dyk", "pancreatic_normal", "lung_colon_aca", "lung_lung_scc", "oral_scc", "kidney_stone",
28
+ "all_benign", "kidney_normal", "Skin_Basal Cell Carcinoma", "Skin_Nail Fungus", "all_early",
29
+ "pancreatic_tumor", "brain_pituitary_tumor", "lymph_mcl", "lung_colon_n", "Skin_Dermato Fibroma",
30
+ "brain_meningioma_tumor", "brain_normal", "breast_benign", "Skin_Actinic Keratosis",
31
+ "lung_lung_aca", "Skin_Vascular Lesion", "oral_normal", "Skin_Seborrheic Keratosis",
32
+ "Skin_Normal Skin", "all_pre", "Skin_Pigmented Benign Keratosis", "cervix_pab", "cervix_sfi",
33
+ "breast_malignant", "Skin_Acne", "cervix_mep", "Skin_Squamous Cell Carcinoma", "Skin_Nevus", "kidney_cyst"
34
+ ]
35
+
36
+ # Transforms (must match training/validation)
37
+ transform = transforms.Compose([
38
+ transforms.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
39
+ transforms.Resize((224, 224)),
40
+ transforms.ToTensor(),
41
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
42
+ ])
43
+
44
+ # Prediction function
45
+ def predict(img):
46
+ img = transform(img).unsqueeze(0).to(device)
47
+ with torch.no_grad():
48
+ outputs = model(img)
49
+ probs = torch.softmax(outputs, dim=1)
50
+ confidences, predicted = torch.max(probs, 1)
51
+ return {class_names[i]: float(probs[0][i]) for i in range(len(class_names))}
52
+
53
+ # Gradio UI
54
+ gr.Interface(
55
+ fn=predict,
56
+ inputs=gr.Image(type="pil"),
57
+ outputs=gr.Label(num_top_classes=5),
58
+ title="Multi-Cancer Classifier",
59
+ description="Upload a histopathology or cancer-related image. The model will predict its cancer type from 47 classes."
60
+ ).launch()