Shadow / backend_app.py
donghyun13245's picture
Update backend_app.py
f099a15 verified
import io
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from flask import Flask, request, jsonify
from flask_cors import CORS # CORS๋„ ๋ฏธ๋ฆฌ ์ถ”๊ฐ€ํ•ด ๋‘ก๋‹ˆ๋‹ค.
# ---------------------------
# 1. ๋ชจ๋ธ ์ •์˜ (ํŒŒ์ผ์— ์ €์žฅ๋œ ๊ตฌ์กฐ์™€ ์ด๋ฆ„์œผ๋กœ ์ˆ˜์ •๋จ)
# ---------------------------
class EmotionCNN(nn.Module):
def __init__(self, num_classes=3):
super(EmotionCNN, self).__init__()
# 'features.0' ๋Œ€์‹  'conv1' (32 ์ฑ„๋„)
self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
# 'features.3' ๋Œ€์‹  'conv2' (64 ์ฑ„๋„)
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
# 'features.6' (128 ์ฑ„๋„)์€ ํŒŒ์ผ์— ์—†์œผ๋ฏ€๋กœ ์‚ญ์ œํ•ฉ๋‹ˆ๋‹ค.
# nn.Sequential์— ์—†๋˜ ํ—ฌํผ ๋ชจ๋“ˆ๋“ค
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(2, 2)
self.dropout = nn.Dropout(0.5)
# 128x128 ์ด๋ฏธ์ง€๊ฐ€ 2๋ฒˆ์˜ pool์„ ๊ฑฐ์น˜๋ฉด 32x32๊ฐ€ ๋ฉ๋‹ˆ๋‹ค.
# conv2์˜ ์ฑ„๋„์ด 64์ด๋ฏ€๋กœ, flatten๋œ ํฌ๊ธฐ๋Š” (64 * 32 * 32) ์ž…๋‹ˆ๋‹ค.
# 'classifier.0' ๋Œ€์‹  'fc1'
self.fc1 = nn.Linear(64 * 32 * 32, 256)
# 'classifier.3' ๋Œ€์‹  'fc2'
self.fc2 = nn.Linear(256, num_classes)
# forward ํ•จ์ˆ˜๋„ Sequential ๋Œ€์‹  ์ˆ˜๋™์œผ๋กœ ์ž‘์„ฑ
def forward(self, x):
# 128x128 -> 64x64
x = self.pool(self.relu(self.conv1(x)))
# 64x64 -> 32x32
x = self.pool(self.relu(self.conv2(x)))
# (3๋ฒˆ์งธ conv ๋ ˆ์ด์–ด ์—†์Œ)
# Flatten
x = x.view(x.size(0), -1)
# Classifier
x = self.relu(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return x
# ---------------------------
# 2. ํ•™์Šต๋œ ๋ชจ๋ธ ๋กœ๋“œ
# ---------------------------
MODEL_PATH = "emotion_cnn_model.pth"
device = torch.device("cpu")
# ์ด์ œ ์ด ๋ชจ๋ธ ๊ตฌ์กฐ๋Š”...
model = EmotionCNN(num_classes=3).to(device)
try:
# ...์ด state_dict์™€ ์™„๋ฒฝํžˆ ์ผ์น˜ํ•ฉ๋‹ˆ๋‹ค!
state_dict = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(state_dict)
model.eval()
print("๋ชจ๋ธ ๋กœ๋“œ ์„ฑ๊ณต! (๊ตฌ์กฐ ์ผ์น˜)")
except Exception as e:
# ์ด ๋ถ€๋ถ„์€ ์ด์ œ ์‹คํ–‰๋˜์ง€ ์•Š์•„์•ผ ํ•ฉ๋‹ˆ๋‹ค.
print(f"๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ: {e}")
class_names = ['Calm', 'Energetic', 'Happy']
# ---------------------------
# 3. ์ด๋ฏธ์ง€ ์ „์ฒ˜๋ฆฌ
# ---------------------------
transform = transforms.Compose([
transforms.Resize((128, 128)),
transforms.ToTensor(),
])
# ---------------------------
# 4. Flask ์•ฑ ์ •์˜
# ---------------------------
app = Flask(__name__)
CORS(app) # CORS ํ—ˆ์šฉ
@app.route("/predict", methods=["POST"])
def predict():
if "image" not in request.files:
return jsonify({"error": "image ํŒŒ์ผ์ด ์—†์Šต๋‹ˆ๋‹ค. form-data์— 'image'๋กœ ๋ณด๋‚ด์ค˜์•ผ ํ•ฉ๋‹ˆ๋‹ค."}), 400
file = request.files["image"]
try:
img = Image.open(file.stream).convert("RGB")
except Exception as e:
return jsonify({"error": f"์ด๋ฏธ์ง€ ๋กœ๋“œ ์‹คํŒจ: {str(e)}"}), 400
# ์ „์ฒ˜๋ฆฌ
img_tensor = transform(img).unsqueeze(0).to(device)
# ์ถ”๋ก 
with torch.no_grad():
outputs = model(img_tensor)
probs = torch.softmax(outputs, dim=1)[0].cpu().tolist()
pred_idx = int(torch.argmax(outputs, dim=1))
result = {
"predicted_label": class_names[pred_idx],
"probabilities": {
class_names[i]: float(probs[i]) for i in range(len(class_names))
}
}
return jsonify(result)
@app.route("/", methods=["GET"])
def health_check():
return jsonify({"status": "ok", "message": "EmotionCNN backend v2 (๊ตฌ์กฐ ์ˆ˜์ •๋จ) running"})
# ์ด ํŒŒ์ผ์€ 'flask run' ๋ช…๋ น์–ด๋กœ ์‹คํ–‰๋˜๋ฏ€๋กœ __main__ check๋Š” ํ•„์š” ์—†์Šต๋‹ˆ๋‹ค.