Omar Alam commited on
Commit
e04e07d
·
1 Parent(s): 676d96d

Push model

Browse files
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+
4
+ import torch
5
+ import torchvision.models as models
6
+
7
+ from torchvision import transforms
8
+ from PIL import Image
9
+ import joblib
10
+ import pickle
11
+
12
+ # Load the classifier
13
+
14
+ classifier = joblib.load('logistic_model.pkl')
15
+
16
+ # Load label mappings
17
+
18
+ label_to_number, number_to_label = pickle.load(open('label_mappings.pkl', 'rb'))
19
+
20
+ # Load a pretrained ResNet model
21
+ model = models.resnet50(pretrained=True)
22
+ model = torch.nn.Sequential(*list(model.children())[:-1]) # Remove the classification layer
23
+ model.eval()
24
+
25
+ # Define image preprocessing
26
+ transform = transforms.Compose([
27
+ transforms.Resize((224, 224)),
28
+ transforms.ToTensor(),
29
+ transforms.Normalize(
30
+ mean=[0.485, 0.456, 0.406],
31
+ std=[0.229, 0.224, 0.225]
32
+ ),
33
+ ])
34
+
35
+ # Function to extract embedding
36
+ def get_embedding(image_path):
37
+ image = Image.open(image_path).convert('RGB')
38
+ image = transform(image).unsqueeze(0) # Add batch dimension
39
+ with torch.no_grad():
40
+ embedding = model(image).squeeze() # Remove extra dimensions
41
+ return embedding.numpy()
42
+
43
+ # Prediction function
44
+ def classify_image(image):
45
+ embedding = get_embedding(image).reshape(1, -1)
46
+ pred = classifier.predict(embedding)
47
+ prob = classifier.predict_proba(embedding)
48
+ pred_label = pred[0]
49
+ pred_index = list(classifier.classes_).index(pred_label)
50
+ confidence = prob[0][pred_index]
51
+ return f"{number_to_label[pred_label]} ({confidence * 100:.2f}%)"
52
+
53
+ # Gradio UI
54
+ demo = gr.Interface(
55
+ fn=classify_image,
56
+ inputs=gr.Image(type="pil"),
57
+ outputs="text",
58
+ title="Mushroom Spore Classifier",
59
+ description="Upload a thumbnail of a mushroom spore and get its predicted class."
60
+ )
61
+
62
+ if __name__ == "__main__":
63
+ demo.launch()
knn_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c859faae2b30a52f9f5fd6a755b7373d03600318086c0f5584082e2d8a4673bb
3
+ size 11242916
label_mappings.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a11339a3c3ec38e7c4c4e94e0555c30576626746c52655e51b9bb644a0fd025e
3
+ size 52930
label_mappings.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a2cbb2b6149096a433ad7848dbca0d89bdbe8651c8e925e57c051835a7772979
3
+ size 33190
logistic_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5cf9d4e355e2df40988bf6ad6b0ec17ecc5ba453263ac7ea33255da367f09498
3
+ size 16663239
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ numpy
5
+ scikit-learn
6
+ Pillow
7
+ joblib