File size: 2,120 Bytes
22c4d3c 831feed 22c4d3c 831feed 22c4d3c 831feed 22c4d3c 831feed 22c4d3c 831feed 22c4d3c 831feed 22c4d3c 831feed 22c4d3c 831feed 22c4d3c 831feed 22c4d3c 831feed 22c4d3c 68ca6ba 22c4d3c 831feed 22c4d3c 831feed 22c4d3c 831feed 22c4d3c 831feed 22c4d3c 831feed 22c4d3c 61118e0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 | import torch
import torch.nn as nn
from torchvision import models, transforms
from flask import Flask, jsonify, request
from PIL import Image
import io
from flask_cors import CORS
# --------------------------
# Flask application
# --------------------------
app = Flask(__name__)
# Turn on CORS handling for every route registered on `app`.
# NOTE(review): with no options flask-cors accepts requests from any origin;
# pass `origins=...` if only a known frontend should call this API.
CORS(app)
# --------------------------
# Compute device
# --------------------------
# Use the default CUDA device whenever PyTorch reports one is available,
# otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# --------------------------
# Preprocessing pipeline (the original note says it matches training)
# --------------------------
# Per-channel RGB mean/std passed to Normalize (the widely used ImageNet values).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

data_transforms = transforms.Compose(
    [
        transforms.Resize(256),      # scale so the shorter side is 256 px
        transforms.CenterCrop(224),  # keep the central 224x224 region
        transforms.ToTensor(),       # PIL image -> float tensor
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
# --------------------------
# Model setup
# --------------------------
# Human-readable label for each output column of the classifier head.
# NOTE(review): the checkpoint is named "resnet18_brain_tumor.pth" but these
# labels look like image categories ("wound", "brain", "lung") — confirm that
# this order matches the label order used when the checkpoint was trained.
class_names = ["wound", "brain", "lung"]

# Build the bare ResNet-18 architecture; all weights come from the checkpoint.
# (`weights=None` replaces the deprecated `pretrained=False`.)
model = models.resnet18(weights=None)
# Swap the stock 1000-way head for one output per label, derived from
# class_names so the head size and label list cannot drift apart.
model.fc = nn.Linear(model.fc.in_features, len(class_names))
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered .pth file cannot execute arbitrary code while the state dict loads.
model.load_state_dict(
    torch.load("resnet18_brain_tumor.pth", map_location=device, weights_only=True)
)
model.to(device)
model.eval()  # inference mode for layers such as BatchNorm/Dropout
# --------------------------
# Predict route
# --------------------------
@app.route("/predict_classify", methods=["POST"])
def predict():
    """Classify an uploaded image and return the predicted label as JSON.

    Expects a multipart/form-data POST carrying the image in the ``file`` field.

    Returns:
        200 with ``{"prediction": <label>}`` on success.
        400 with ``{"error": ...}`` when no file is sent or it cannot be
        decoded as an image.
        500 with ``{"error": ...}`` when preprocessing or inference fails.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file provided"}), 400
    file = request.files["file"]

    # Decode the upload entirely in memory (nothing is written to disk).
    # An empty, corrupt, truncated or oversized upload is a client error, so it
    # gets 400 instead of being reported as a server failure.
    try:
        image = Image.open(io.BytesIO(file.read())).convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        return jsonify({"error": f"Invalid image file: {e}"}), 400

    try:
        # Prepend a batch dimension of size 1 and move to the model's device.
        input_tensor = data_transforms(image).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(input_tensor)
        pred_idx = torch.argmax(outputs, dim=1).item()
        pred_label = class_names[pred_idx]
    except Exception as e:  # request boundary: keep the traceback in the server log
        app.logger.exception("Prediction failed")
        return jsonify({"error": str(e)}), 500

    return jsonify({"prediction": pred_label})
# --------------------------
# Run server
# --------------------------
if __name__ == '__main__':
    import os

    # The Werkzeug interactive debugger lets anyone who can reach the server
    # run arbitrary Python, so it must not be on by default while listening on
    # every interface (0.0.0.0). Opt in for local development with FLASK_DEBUG=1.
    debug_mode = os.environ.get("FLASK_DEBUG", "0").strip().lower() in ("1", "true", "yes")
    app.run(debug=debug_mode, host="0.0.0.0", port=7860)
|