fadiyahalanazi commited on
Commit
3f0a595
Β·
verified Β·
1 Parent(s): 0f86136

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import subprocess
3
+ import sys
4
+
5
+ # βœ… Function to install missing packages efficiently
6
+ def install(package):
7
+ try:
8
+ __import__(package.split("==")[0]) # Try to import package before installing
9
+ except ImportError:
10
+ subprocess.run([sys.executable, "-m", "pip", "install", package])
11
+
12
+ # βœ… List of dependencies to install
13
+ dependencies = [
14
+ "torch>=2.0.0",
15
+ "torchvision>=0.15.0",
16
+ "transformers",
17
+ "gradio",
18
+ "pillow",
19
+ "pandas",
20
+ "opencv-python-headless",
21
+ "scikit-learn==1.3.0"
22
+ ]
23
+
24
+ # βœ… Install dependencies
25
+ for package in dependencies:
26
+ install(package)
27
+ import subprocess
28
+ import sys
29
+
30
+ # βœ… Function to install missing packages
31
+ def install(package):
32
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package])
33
+
34
+ # βœ… Ensure required libraries are installed
35
+ for package in ["torch", "torchvision", "transformers", "gradio", "pillow", "pandas", "opencv-python", "scikit-learn"]:
36
+ try:
37
+ __import__(package)
38
+ except ImportError:
39
+ install(package)
40
+
41
+ # βœ… Import libraries after installation
42
+ import torch
43
+ import torch.nn as nn
44
+ import torchvision.transforms as transforms
45
+ import torchvision.models as models
46
+ import pandas as pd
47
+ from PIL import Image
48
+ import gradio as gr
49
+ from sklearn.preprocessing import LabelEncoder
50
+
51
+ # βœ… Load metadata
52
+ CSV_PATH = "HAM10000_metadata.csv"
53
+ DATA_PATH = "ham10000_images/"
54
+
55
+ df = pd.read_csv(CSV_PATH)
56
+ label_encoder = LabelEncoder()
57
+ df["label"] = label_encoder.fit_transform(df["dx"]) # Convert labels to numbers
58
+ classes = label_encoder.classes_ # Get class names
59
+
60
+ # βœ… Define image transformation
61
+ transform = transforms.Compose([
62
+ transforms.Resize((224, 224)),
63
+ transforms.ToTensor(),
64
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
65
+ ])
66
+
67
+ # βœ… Load a pre-trained EfficientNet model
68
+ model = models.efficientnet_b0(pretrained=True)
69
+ num_features = model.classifier[1].in_features
70
+ model.classifier[1] = nn.Linear(num_features, len(classes)) # Adjust for 7 classes
71
+ model.load_state_dict(torch.load("ham10000_model.pth", map_location=torch.device('cpu')))
72
+ model.eval()
73
+
74
+ # βœ… Function to classify skin disease
75
+ def classify_skin_disease(image):
76
+ image = Image.fromarray(image) # Convert to PIL image
77
+ image = transform(image).unsqueeze(0) # Apply transformations
78
+
79
+ with torch.no_grad():
80
+ output = model(image)
81
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
82
+
83
+ # Convert probabilities to dictionary
84
+ results = {classes[i]: f"{probabilities[i].item():.2%}" for i in range(len(classes))}
85
+ return results
86
+
87
+ # βœ… Create Gradio Interface
88
+ iface = gr.Interface(
89
+ fn=classify_skin_disease,
90
+ inputs=gr.Image(type="numpy"),
91
+ outputs=gr.Label(num_top_classes=3),
92
+ title="🩺 AI Skin Disease Classifier",
93
+ description="πŸ“· Upload a skin lesion image and the model will classify it.",
94
+ examples=["example_eczema.jpg", "example_melanoma.jpg"], # Add sample images
95
+ )
96
+
97
+ # βœ… Run Gradio App
98
+ if __name__ == "__main__":
99
+ iface.launch()