kyanmahajan commited on
Commit
3605994
·
verified ·
1 Parent(s): 7114d94

Upload 2 files

Browse files
Files changed (2) hide show
  1. resnet18_brain_tumor.pth +3 -0
  2. xray_classifier.py +106 -0
resnet18_brain_tumor.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d369161266ab606a0feb71d303704d4f2eb53d116c6473f42794e2d6c31d8841
3
+ size 44792676
xray_classifier.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import models, transforms
5
+ from flask import Flask, jsonify, request, render_template
6
+ from PIL import Image
7
+ import os
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ from PIL import Image
11
+ import numpy as np
12
+ import cv2
13
+ import cv2
14
+ import torch
15
+
16
+ from pytorch_grad_cam import GradCAM
17
+ from pytorch_grad_cam.utils.image import show_cam_on_image
18
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
19
+
20
+ from flask_cors import CORS
21
+
22
+
23
+
24
+
25
+ app = Flask(__name__)
26
+ CORS(app)
27
+ os.makedirs("static", exist_ok=True)
28
+
29
+ # Device setup
30
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+
32
+
33
+ # Transform setup (same as training)
34
+ data_transforms = transforms.Compose([
35
+ transforms.Resize(256),
36
+ transforms.CenterCrop(224),
37
+ transforms.ToTensor(),
38
+ transforms.Normalize(
39
+ mean=[0.485, 0.456, 0.406],
40
+ std=[0.229, 0.224, 0.225]
41
+ )
42
+ ])
43
+
44
+
45
+ model = models.resnet18(pretrained=False);
46
+ model.fc = nn.Linear(model.fc.in_features, 3);
47
+ model.load_state_dict(torch.load("resnet18_brain_tumor.pth", map_location=device))
48
+
49
+ model.to(device)
50
+ model.eval()
51
+
52
+
53
+ class_names = [
54
+ "wound",
55
+ "brain",
56
+ "lung"
57
+ ]
58
+
59
+ # @app.route("/")
60
+ # def home():
61
+ # return render_template("index.html")
62
+
63
+ @app.route("/predict_classify", methods=["POST"])
64
+ def predict():
65
+ if "file" not in request.files:
66
+ return jsonify({"error": "No file provided"}), 400
67
+
68
+ file = request.files["file"]
69
+ filepath = os.path.join("static", file.filename)
70
+ file.save(filepath)
71
+
72
+ try:
73
+ image = Image.open(filepath).convert("RGB")
74
+ input_tensor = data_transforms(image).unsqueeze(0).to(device)
75
+
76
+ with torch.no_grad():
77
+ output = model(input_tensor)
78
+ pred_idx = torch.argmax(output, dim=1).item()
79
+ pred_label = class_names[pred_idx]
80
+
81
+
82
+
83
+
84
+
85
+ file={
86
+ "prediction": pred_label,
87
+
88
+ }
89
+ print(file)
90
+ return jsonify(file)
91
+
92
+
93
+
94
+
95
+
96
+
97
+
98
+
99
+ except Exception as e:
100
+ return jsonify({"error": str(e)}), 500
101
+
102
+
103
+ if __name__ == '__main__':
104
+ app.run(debug=True, host="127.0.0.1", port=5000)
105
+
106
+