File size: 2,120 Bytes
22c4d3c
 
 
831feed
22c4d3c
831feed
22c4d3c
 
831feed
 
 
22c4d3c
 
 
831feed
22c4d3c
831feed
22c4d3c
 
831feed
22c4d3c
831feed
22c4d3c
 
 
 
 
 
 
 
 
 
831feed
 
 
 
 
22c4d3c
 
 
 
831feed
22c4d3c
831feed
 
 
22c4d3c
 
 
 
 
 
68ca6ba
22c4d3c
831feed
 
 
22c4d3c
831feed
 
 
 
22c4d3c
831feed
 
22c4d3c
 
831feed
 
 
22c4d3c
 
 
 
 
831feed
 
 
22c4d3c
61118e0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import torch.nn as nn
from torchvision import models, transforms
from flask import Flask, jsonify, request
from PIL import Image
import io
from flask_cors import CORS

# --------------------------
# Flask setup
# --------------------------
app = Flask(__name__)
CORS(app)  # register flask_cors handling on the application

# --------------------------
# Device setup
# --------------------------
# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# --------------------------
# Transform setup (same as training)
# --------------------------
# Preprocessing applied to every uploaded image before inference:
# resize, centre-crop to 224, convert to a tensor, then normalise per channel.
data_transforms = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# --------------------------
# Model setup
# --------------------------
# Labels returned to the client; index i corresponds to output logit i.
# NOTE(review): the order must match the label indices used during training,
# and the checkpoint name mentions "brain_tumor" only — confirm both.
class_names = ["wound", "brain", "lung"]

# Build the architecture with randomly initialised weights. Calling
# resnet18() without a weight argument does this on both the legacy
# (`pretrained=`) and current (`weights=`) torchvision APIs, and avoids the
# deprecation warning that `pretrained=False` triggers on recent releases.
model = models.resnet18()

# Replace the classification head so it emits one logit per known class.
model.fc = nn.Linear(model.fc.in_features, len(class_names))

# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source. Consider `weights_only=True` if the deployed torch supports it.
state_dict = torch.load("resnet18_brain_tumor.pth", map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # inference mode: fixes batch-norm statistics, disables dropout

# --------------------------
# Predict route
# --------------------------
@app.route("/predict_classify", methods=["POST"])
def predict():
    """Classify an uploaded image and return the predicted class name.

    Expects a multipart POST whose ``file`` field holds the image.

    Returns:
        ``{"prediction": <label>}`` on success.
        ``{"error": <message>}`` with status 400 when the ``file`` field is
        missing, empty, or cannot be decoded as an image.
        ``{"error": <message>}`` with status 500 when reading the upload,
        preprocessing, or inference fails unexpectedly.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file provided"}), 400

    file = request.files["file"]

    try:
        # Work entirely in memory; the upload is never written to disk.
        image_bytes = file.read()
        if not image_bytes:
            return jsonify({"error": "Empty file provided"}), 400

        try:
            # The context manager closes the decoder's handle; convert()
            # returns an independent image that stays usable afterwards.
            with Image.open(io.BytesIO(image_bytes)) as uploaded:
                image = uploaded.convert("RGB")
        except OSError as e:
            # Pillow reports unrecognised or truncated image data as OSError
            # (UnidentifiedImageError is a subclass). That is a bad request
            # from the client, not a server fault.
            return jsonify({"error": f"Invalid image file: {e}"}), 400

        # Add the batch dimension and move the input to the model's device.
        input_tensor = data_transforms(image).unsqueeze(0).to(device)

        # Gradients are not needed for inference.
        with torch.no_grad():
            outputs = model(input_tensor)
            pred_idx = torch.argmax(outputs, dim=1).item()
            pred_label = class_names[pred_idx]

        return jsonify({
            "prediction": pred_label
        })

    except Exception as e:
        # Top-level boundary: record the traceback server-side so failures are
        # diagnosable, and keep the existing JSON error shape for clients.
        app.logger.exception("Prediction failed")
        return jsonify({"error": str(e)}), 500


# --------------------------
# Run server
# --------------------------
if __name__ == '__main__':
    # Local import keeps this entry-point block self-contained.
    import os

    # The server listens on all interfaces, so the Werkzeug interactive
    # debugger (which allows arbitrary code execution from the browser) must
    # not be on by default. Opt in explicitly with FLASK_DEBUG=1 during
    # local development.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {
        "1", "true", "yes", "on",
    }
    app.run(debug=debug_enabled, host="0.0.0.0", port=7860)