import io import torch import torch.nn as nn import torch.nn.functional as F from torchvision import transforms from PIL import Image from flask import Flask, request, jsonify from flask_cors import CORS # CORS도 미리 추가해 둡니다. # --------------------------- # 1. 모델 정의 (파일에 저장된 구조와 이름으로 수정됨) # --------------------------- class EmotionCNN(nn.Module): def __init__(self, num_classes=3): super(EmotionCNN, self).__init__() # 'features.0' 대신 'conv1' (32 채널) self.conv1 = nn.Conv2d(3, 32, 3, padding=1) # 'features.3' 대신 'conv2' (64 채널) self.conv2 = nn.Conv2d(32, 64, 3, padding=1) # 'features.6' (128 채널)은 파일에 없으므로 삭제합니다. # nn.Sequential에 없던 헬퍼 모듈들 self.relu = nn.ReLU() self.pool = nn.MaxPool2d(2, 2) self.dropout = nn.Dropout(0.5) # 128x128 이미지가 2번의 pool을 거치면 32x32가 됩니다. # conv2의 채널이 64이므로, flatten된 크기는 (64 * 32 * 32) 입니다. # 'classifier.0' 대신 'fc1' self.fc1 = nn.Linear(64 * 32 * 32, 256) # 'classifier.3' 대신 'fc2' self.fc2 = nn.Linear(256, num_classes) # forward 함수도 Sequential 대신 수동으로 작성 def forward(self, x): # 128x128 -> 64x64 x = self.pool(self.relu(self.conv1(x))) # 64x64 -> 32x32 x = self.pool(self.relu(self.conv2(x))) # (3번째 conv 레이어 없음) # Flatten x = x.view(x.size(0), -1) # Classifier x = self.relu(self.fc1(x)) x = self.dropout(x) x = self.fc2(x) return x # --------------------------- # 2. 학습된 모델 로드 # --------------------------- MODEL_PATH = "emotion_cnn_model.pth" device = torch.device("cpu") # 이제 이 모델 구조는... model = EmotionCNN(num_classes=3).to(device) try: # ...이 state_dict와 완벽히 일치합니다! state_dict = torch.load(MODEL_PATH, map_location=device) model.load_state_dict(state_dict) model.eval() print("모델 로드 성공! (구조 일치)") except Exception as e: # 이 부분은 이제 실행되지 않아야 합니다. print(f"모델 로드 실패: {e}") class_names = ['Calm', 'Energetic', 'Happy'] # --------------------------- # 3. 이미지 전처리 # --------------------------- transform = transforms.Compose([ transforms.Resize((128, 128)), transforms.ToTensor(), ]) # --------------------------- # 4. Flask 앱 정의 # --------------------------- app = Flask(__name__) CORS(app) # CORS 허용 @app.route("/predict", methods=["POST"]) def predict(): if "image" not in request.files: return jsonify({"error": "image 파일이 없습니다. form-data에 'image'로 보내줘야 합니다."}), 400 file = request.files["image"] try: img = Image.open(file.stream).convert("RGB") except Exception as e: return jsonify({"error": f"이미지 로드 실패: {str(e)}"}), 400 # 전처리 img_tensor = transform(img).unsqueeze(0).to(device) # 추론 with torch.no_grad(): outputs = model(img_tensor) probs = torch.softmax(outputs, dim=1)[0].cpu().tolist() pred_idx = int(torch.argmax(outputs, dim=1)) result = { "predicted_label": class_names[pred_idx], "probabilities": { class_names[i]: float(probs[i]) for i in range(len(class_names)) } } return jsonify(result) @app.route("/", methods=["GET"]) def health_check(): return jsonify({"status": "ok", "message": "EmotionCNN backend v2 (구조 수정됨) running"}) # 이 파일은 'flask run' 명령어로 실행되므로 __main__ check는 필요 없습니다.