ma4389 commited on
Commit
da82c81
·
verified ·
1 Parent(s): 1774921

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +50 -0
  2. cancer_model.pth +3 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import models, transforms
5
+ from PIL import Image
6
+
7
+ # Load the trained model
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
+ # Define model
11
+ model = models.resnet50(weights=None)
12
+ in_features = model.fc.in_features
13
+ model.fc = nn.Sequential(
14
+ nn.Linear(in_features, 512),
15
+ nn.ReLU(),
16
+ nn.Dropout(0.4),
17
+ nn.Linear(512, 47) # 47 classes
18
+ )
19
+ model.load_state_dict(torch.load("cancer_model.pth", map_location=device))
20
+ model.to(device)
21
+ model.eval()
22
+
23
+ # Label mapping (update this with actual class names if available)
24
+ class_names = [f"Class {i}" for i in range(47)]
25
+
26
+ # Transforms (same as validation)
27
+ transform = transforms.Compose([
28
+ transforms.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
29
+ transforms.Resize((224, 224)),
30
+ transforms.ToTensor(),
31
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
32
+ ])
33
+
34
+ # Prediction function
35
+ def predict(img):
36
+ img = transform(img).unsqueeze(0).to(device)
37
+ with torch.no_grad():
38
+ outputs = model(img)
39
+ probs = torch.softmax(outputs, dim=1)
40
+ confidences, predicted = torch.max(probs, 1)
41
+ return {class_names[i]: float(probs[0][i]) for i in range(len(class_names))}
42
+
43
+ # Gradio UI
44
+ gr.Interface(
45
+ fn=predict,
46
+ inputs=gr.Image(type="pil"),
47
+ outputs=gr.Label(num_top_classes=5),
48
+ title="Multi-Cancer Classifier",
49
+ description="Upload a histopathology or cancer-related image. The model will predict its cancer type (47 classes)."
50
+ ).launch()
cancer_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df9ce7117323af782c866933372c1f24af027c32e374ee63409c866f4ec9d8a7
3
+ size 98644160
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch>=2.0
2
+ torchvision
3
+ gradio>=4.0
4
+ Pillow