Spaces:
Sleeping
Sleeping
| import io | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from PIL import Image | |
| from flask import Flask, request, jsonify | |
| from flask_cors import CORS # CORS๋ ๋ฏธ๋ฆฌ ์ถ๊ฐํด ๋ก๋๋ค. | |
| # --------------------------- | |
| # 1. ๋ชจ๋ธ ์ ์ (ํ์ผ์ ์ ์ฅ๋ ๊ตฌ์กฐ์ ์ด๋ฆ์ผ๋ก ์์ ๋จ) | |
| # --------------------------- | |
| class EmotionCNN(nn.Module): | |
| def __init__(self, num_classes=3): | |
| super(EmotionCNN, self).__init__() | |
| # 'features.0' ๋์ 'conv1' (32 ์ฑ๋) | |
| self.conv1 = nn.Conv2d(3, 32, 3, padding=1) | |
| # 'features.3' ๋์ 'conv2' (64 ์ฑ๋) | |
| self.conv2 = nn.Conv2d(32, 64, 3, padding=1) | |
| # 'features.6' (128 ์ฑ๋)์ ํ์ผ์ ์์ผ๋ฏ๋ก ์ญ์ ํฉ๋๋ค. | |
| # nn.Sequential์ ์๋ ํฌํผ ๋ชจ๋๋ค | |
| self.relu = nn.ReLU() | |
| self.pool = nn.MaxPool2d(2, 2) | |
| self.dropout = nn.Dropout(0.5) | |
| # 128x128 ์ด๋ฏธ์ง๊ฐ 2๋ฒ์ pool์ ๊ฑฐ์น๋ฉด 32x32๊ฐ ๋ฉ๋๋ค. | |
| # conv2์ ์ฑ๋์ด 64์ด๋ฏ๋ก, flatten๋ ํฌ๊ธฐ๋ (64 * 32 * 32) ์ ๋๋ค. | |
| # 'classifier.0' ๋์ 'fc1' | |
| self.fc1 = nn.Linear(64 * 32 * 32, 256) | |
| # 'classifier.3' ๋์ 'fc2' | |
| self.fc2 = nn.Linear(256, num_classes) | |
| # forward ํจ์๋ Sequential ๋์ ์๋์ผ๋ก ์์ฑ | |
| def forward(self, x): | |
| # 128x128 -> 64x64 | |
| x = self.pool(self.relu(self.conv1(x))) | |
| # 64x64 -> 32x32 | |
| x = self.pool(self.relu(self.conv2(x))) | |
| # (3๋ฒ์งธ conv ๋ ์ด์ด ์์) | |
| # Flatten | |
| x = x.view(x.size(0), -1) | |
| # Classifier | |
| x = self.relu(self.fc1(x)) | |
| x = self.dropout(x) | |
| x = self.fc2(x) | |
| return x | |
| # --------------------------- | |
| # 2. ํ์ต๋ ๋ชจ๋ธ ๋ก๋ | |
| # --------------------------- | |
| MODEL_PATH = "emotion_cnn_model.pth" | |
| device = torch.device("cpu") | |
| # ์ด์ ์ด ๋ชจ๋ธ ๊ตฌ์กฐ๋... | |
| model = EmotionCNN(num_classes=3).to(device) | |
| try: | |
| # ...์ด state_dict์ ์๋ฒฝํ ์ผ์นํฉ๋๋ค! | |
| state_dict = torch.load(MODEL_PATH, map_location=device) | |
| model.load_state_dict(state_dict) | |
| model.eval() | |
| print("๋ชจ๋ธ ๋ก๋ ์ฑ๊ณต! (๊ตฌ์กฐ ์ผ์น)") | |
| except Exception as e: | |
| # ์ด ๋ถ๋ถ์ ์ด์ ์คํ๋์ง ์์์ผ ํฉ๋๋ค. | |
| print(f"๋ชจ๋ธ ๋ก๋ ์คํจ: {e}") | |
| class_names = ['Calm', 'Energetic', 'Happy'] | |
| # --------------------------- | |
| # 3. ์ด๋ฏธ์ง ์ ์ฒ๋ฆฌ | |
| # --------------------------- | |
| transform = transforms.Compose([ | |
| transforms.Resize((128, 128)), | |
| transforms.ToTensor(), | |
| ]) | |
| # --------------------------- | |
| # 4. Flask ์ฑ ์ ์ | |
| # --------------------------- | |
| app = Flask(__name__) | |
| CORS(app) # CORS ํ์ฉ | |
| def predict(): | |
| if "image" not in request.files: | |
| return jsonify({"error": "image ํ์ผ์ด ์์ต๋๋ค. form-data์ 'image'๋ก ๋ณด๋ด์ค์ผ ํฉ๋๋ค."}), 400 | |
| file = request.files["image"] | |
| try: | |
| img = Image.open(file.stream).convert("RGB") | |
| except Exception as e: | |
| return jsonify({"error": f"์ด๋ฏธ์ง ๋ก๋ ์คํจ: {str(e)}"}), 400 | |
| # ์ ์ฒ๋ฆฌ | |
| img_tensor = transform(img).unsqueeze(0).to(device) | |
| # ์ถ๋ก | |
| with torch.no_grad(): | |
| outputs = model(img_tensor) | |
| probs = torch.softmax(outputs, dim=1)[0].cpu().tolist() | |
| pred_idx = int(torch.argmax(outputs, dim=1)) | |
| result = { | |
| "predicted_label": class_names[pred_idx], | |
| "probabilities": { | |
| class_names[i]: float(probs[i]) for i in range(len(class_names)) | |
| } | |
| } | |
| return jsonify(result) | |
| def health_check(): | |
| return jsonify({"status": "ok", "message": "EmotionCNN backend v2 (๊ตฌ์กฐ ์์ ๋จ) running"}) | |
| # ์ด ํ์ผ์ 'flask run' ๋ช ๋ น์ด๋ก ์คํ๋๋ฏ๋ก __main__ check๋ ํ์ ์์ต๋๋ค. | |