Spaces:
Sleeping
Sleeping
2508181758
Browse files- main.py +58 -69
- web/script.js +20 -5
main.py
CHANGED
|
@@ -4,16 +4,14 @@ from typing import List, Dict, Any
|
|
| 4 |
import torch
|
| 5 |
import torch.nn as nn
|
| 6 |
import torch.optim as optim
|
| 7 |
-
# from torchvision import datasets -> datasetsライブラリを使うため削除
|
| 8 |
from torchvision import transforms
|
| 9 |
-
from torch.utils.data import DataLoader
|
| 10 |
import numpy as np
|
| 11 |
import base64
|
| 12 |
from io import BytesIO
|
| 13 |
from PIL import Image
|
| 14 |
import random
|
| 15 |
-
import
|
| 16 |
-
from datasets import load_dataset # ★★★ Hugging Face datasetsライブラリをインポート
|
| 17 |
|
| 18 |
# FastAPIアプリケーションインスタンスを作成
|
| 19 |
app = FastAPI()
|
|
@@ -105,96 +103,84 @@ class PlayerModel(nn.Module):
|
|
| 105 |
x = layer(x)
|
| 106 |
return x
|
| 107 |
|
| 108 |
-
# --- グローバル変数とデータ準備 (
|
|
|
|
| 109 |
device = torch.device("cpu")
|
| 110 |
-
|
| 111 |
-
# 1. Hugging Face HubからMNISTデータセットをロード
|
| 112 |
mnist_dataset = load_dataset("mnist")
|
|
|
|
| 113 |
|
| 114 |
-
# 2. torchvisionのtransformを定義
|
| 115 |
-
transform = transforms.Compose([
|
| 116 |
-
transforms.ToTensor(),
|
| 117 |
-
# transforms.Normalize((0.1307,), (0.3081,)) # 必要に応じて正規化
|
| 118 |
-
])
|
| 119 |
-
|
| 120 |
-
# 3. データセットにtransformを適用する関数を定義
|
| 121 |
def apply_transforms(examples):
|
| 122 |
-
# PIL Imageのリストをテンソルのリストに変換
|
| 123 |
examples['image'] = [transform(image.convert("L")) for image in examples['image']]
|
| 124 |
return examples
|
| 125 |
|
| 126 |
-
# 4. データセットにtransformを適用
|
| 127 |
mnist_dataset.set_transform(apply_transforms)
|
| 128 |
-
|
| 129 |
-
# 5. DataLoaderを準備
|
| 130 |
train_subset = mnist_dataset['train'].select(range(1000))
|
| 131 |
train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
|
| 132 |
|
| 133 |
-
# 6. テスト用の画像リストを作成 (DataLoaderと同じ (image, label) タプルの形式を維持)
|
| 134 |
test_images = []
|
| 135 |
-
# メモリ使用量を考慮し、テスト画像は1000個に絞る
|
| 136 |
test_subset_for_inference = mnist_dataset['test'].shuffle().select(range(1000))
|
| 137 |
for item in test_subset_for_inference:
|
| 138 |
-
|
| 139 |
-
image_tensor = item['image'].unsqueeze(0) # バッチ次元 (1) を追加
|
| 140 |
label_tensor = torch.tensor(item['label'])
|
| 141 |
test_images.append((image_tensor, label_tensor))
|
| 142 |
|
| 143 |
-
|
| 144 |
-
trained_player_model = None
|
| 145 |
|
| 146 |
-
# --- バックエンドロジック ---
|
| 147 |
def get_enemy():
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
current_enemy = {"image": image, "label": label}
|
| 151 |
|
| 152 |
-
img_pil = transforms.ToPILImage()(
|
| 153 |
buffered = BytesIO()
|
| 154 |
img_pil.save(buffered, format="PNG")
|
| 155 |
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 156 |
|
| 157 |
-
return {
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
-
def
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
if not layer_configs:
|
| 162 |
-
return {"
|
| 163 |
-
|
| 164 |
try:
|
| 165 |
model = PlayerModel(layer_configs).to(device)
|
| 166 |
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
| 167 |
loss_fn = nn.CrossEntropyLoss()
|
| 168 |
|
| 169 |
model.train()
|
| 170 |
-
for epoch in range(3): # 3エポック学習
|
| 171 |
-
|
| 172 |
-
for batch_idx, batch in enumerate(train_loader):
|
| 173 |
data, target = batch['image'].to(device), batch['label'].to(device)
|
| 174 |
optimizer.zero_grad()
|
| 175 |
output = model(data)
|
| 176 |
loss = loss_fn(output, target)
|
| 177 |
loss.backward()
|
| 178 |
optimizer.step()
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
trained_player_model = model
|
| 182 |
-
return {"success": True, "message": "モデルの訓練が完了しました!"}
|
| 183 |
except Exception as e:
|
| 184 |
-
print(f"Error during training: {e}")
|
| 185 |
-
return {"
|
| 186 |
-
|
| 187 |
-
def run_inference():
|
| 188 |
-
global trained_player_model, current_enemy
|
| 189 |
-
if trained_player_model is None:
|
| 190 |
-
return {"error": "モデルが訓練されていません。"}
|
| 191 |
|
| 192 |
-
|
| 193 |
-
current_enemy = {"image": image, "label": label}
|
| 194 |
-
|
| 195 |
-
model = trained_player_model
|
| 196 |
model.eval()
|
| 197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
intermediate_outputs = {}
|
| 199 |
hooks = []
|
| 200 |
def get_hook(name):
|
|
@@ -206,7 +192,6 @@ def run_inference():
|
|
| 206 |
hooks.append(layer.register_forward_hook(get_hook(name)))
|
| 207 |
|
| 208 |
with torch.no_grad():
|
| 209 |
-
image_tensor = current_enemy["image"].to(device)
|
| 210 |
output = model(image_tensor)
|
| 211 |
|
| 212 |
for h in hooks: h.remove()
|
|
@@ -224,34 +209,38 @@ def run_inference():
|
|
| 224 |
weights[name + '_w'] = layer.weight.cpu().detach().numpy().tolist()
|
| 225 |
weights[name + '_b'] = layer.bias.cpu().detach().numpy().tolist()
|
| 226 |
|
| 227 |
-
is_correct = (prediction ==
|
| 228 |
-
|
| 229 |
-
img_pil = transforms.ToPILImage()(image.squeeze(0))
|
| 230 |
-
buffered = BytesIO()
|
| 231 |
-
img_pil.save(buffered, format="PNG")
|
| 232 |
-
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 233 |
|
|
|
|
| 234 |
return {
|
| 235 |
-
"prediction": prediction,
|
|
|
|
|
|
|
| 236 |
"confidence": confidence,
|
| 237 |
-
"image_b64":
|
| 238 |
"architecture": [{"type": "Input", "name": "input", "shape": [1, 28, 28]}] + model.architecture_info,
|
| 239 |
"outputs": intermediate_outputs,
|
| 240 |
"weights": weights
|
| 241 |
}
|
| 242 |
|
| 243 |
-
# --- FastAPI Endpoints
|
| 244 |
@app.get("/api/get_enemy")
|
| 245 |
async def get_enemy_endpoint():
|
| 246 |
return get_enemy()
|
| 247 |
|
| 248 |
-
@app.post("/api/train_player_model")
|
| 249 |
-
async def train_player_model_endpoint(layer_configs: List[Dict[str, Any]] = Body(...)):
|
| 250 |
-
return train_player_model(layer_configs)
|
| 251 |
-
|
| 252 |
@app.post("/api/run_inference")
|
| 253 |
-
async def run_inference_endpoint():
|
| 254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
|
| 256 |
-
# --- 静的ファイルの配信
|
| 257 |
app.mount("/", StaticFiles(directory="web", html=True), name="static")
|
|
|
|
| 4 |
import torch
|
| 5 |
import torch.nn as nn
|
| 6 |
import torch.optim as optim
|
|
|
|
| 7 |
from torchvision import transforms
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
import numpy as np
|
| 10 |
import base64
|
| 11 |
from io import BytesIO
|
| 12 |
from PIL import Image
|
| 13 |
import random
|
| 14 |
+
from datasets import load_dataset
|
|
|
|
| 15 |
|
| 16 |
# FastAPIアプリケーションインスタンスを作成
|
| 17 |
app = FastAPI()
|
|
|
|
| 103 |
x = layer(x)
|
| 104 |
return x
|
| 105 |
|
| 106 |
+
# --- グローバル変数とデータ準備 (ステートレス対応) ---
|
| 107 |
+
# これらの変数はサーバー起動時に一度だけ初期化され、リクエスト間で変更されない定数として扱う
|
| 108 |
device = torch.device("cpu")
|
|
|
|
|
|
|
| 109 |
mnist_dataset = load_dataset("mnist")
|
| 110 |
+
transform = transforms.Compose([transforms.ToTensor()])
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
def apply_transforms(examples):
|
|
|
|
| 113 |
examples['image'] = [transform(image.convert("L")) for image in examples['image']]
|
| 114 |
return examples
|
| 115 |
|
|
|
|
| 116 |
mnist_dataset.set_transform(apply_transforms)
|
|
|
|
|
|
|
| 117 |
train_subset = mnist_dataset['train'].select(range(1000))
|
| 118 |
train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
|
| 119 |
|
|
|
|
| 120 |
test_images = []
|
|
|
|
| 121 |
test_subset_for_inference = mnist_dataset['test'].shuffle().select(range(1000))
|
| 122 |
for item in test_subset_for_inference:
|
| 123 |
+
image_tensor = item['image'].unsqueeze(0)
|
|
|
|
| 124 |
label_tensor = torch.tensor(item['label'])
|
| 125 |
test_images.append((image_tensor, label_tensor))
|
| 126 |
|
| 127 |
+
# --- バックエンドロジック (ステートレス関数) ---
|
|
|
|
| 128 |
|
|
|
|
| 129 |
def get_enemy():
|
| 130 |
+
"""新しい敵の画像(base64)と正解ラベルを返す。サーバー側では状態を保持しない。"""
|
| 131 |
+
image_tensor, label_tensor = random.choice(test_images)
|
|
|
|
| 132 |
|
| 133 |
+
img_pil = transforms.ToPILImage()(image_tensor.squeeze(0))
|
| 134 |
buffered = BytesIO()
|
| 135 |
img_pil.save(buffered, format="PNG")
|
| 136 |
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 137 |
|
| 138 |
+
return {
|
| 139 |
+
"image_b64": "data:image/png;base64," + img_str,
|
| 140 |
+
"label": label_tensor.item()
|
| 141 |
+
}
|
| 142 |
|
| 143 |
+
def run_inference(layer_configs: list, enemy_image_b64: str, enemy_label: int):
|
| 144 |
+
"""
|
| 145 |
+
リクエストごとにモデルを構築・訓練し、与えられた敵データで推論を実行する。
|
| 146 |
+
サーバー側では状態を一切保持しない。
|
| 147 |
+
"""
|
| 148 |
+
# 1. モデルをその場で構築し、訓練する
|
| 149 |
if not layer_configs:
|
| 150 |
+
return {"error": "モデルが空です。"}
|
|
|
|
| 151 |
try:
|
| 152 |
model = PlayerModel(layer_configs).to(device)
|
| 153 |
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
| 154 |
loss_fn = nn.CrossEntropyLoss()
|
| 155 |
|
| 156 |
model.train()
|
| 157 |
+
for epoch in range(3): # 毎回3エポック学習
|
| 158 |
+
for batch in train_loader:
|
|
|
|
| 159 |
data, target = batch['image'].to(device), batch['label'].to(device)
|
| 160 |
optimizer.zero_grad()
|
| 161 |
output = model(data)
|
| 162 |
loss = loss_fn(output, target)
|
| 163 |
loss.backward()
|
| 164 |
optimizer.step()
|
| 165 |
+
print("On-the-fly training for inference finished.")
|
|
|
|
|
|
|
|
|
|
| 166 |
except Exception as e:
|
| 167 |
+
print(f"Error during on-the-fly training: {e}")
|
| 168 |
+
return {"error": f"推論中のモデル構築・訓練エラー: {e}"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
+
# 2. クライアントから送られてきた敵画像で推論する
|
|
|
|
|
|
|
|
|
|
| 171 |
model.eval()
|
| 172 |
|
| 173 |
+
# Base64文字列から画像テンソルにデコード
|
| 174 |
+
try:
|
| 175 |
+
header, encoded = enemy_image_b64.split(",", 1)
|
| 176 |
+
image_data = base64.b64decode(encoded)
|
| 177 |
+
image_pil = Image.open(BytesIO(image_data)).convert("L")
|
| 178 |
+
image_tensor = transforms.ToTensor()(image_pil).unsqueeze(0).to(device)
|
| 179 |
+
except Exception as e:
|
| 180 |
+
print(f"Error decoding enemy image: {e}")
|
| 181 |
+
return {"error": f"敵画像のデコードエラー: {e}"}
|
| 182 |
+
|
| 183 |
+
# 3. 推論と中間出力のキャプチャ
|
| 184 |
intermediate_outputs = {}
|
| 185 |
hooks = []
|
| 186 |
def get_hook(name):
|
|
|
|
| 192 |
hooks.append(layer.register_forward_hook(get_hook(name)))
|
| 193 |
|
| 194 |
with torch.no_grad():
|
|
|
|
| 195 |
output = model(image_tensor)
|
| 196 |
|
| 197 |
for h in hooks: h.remove()
|
|
|
|
| 209 |
weights[name + '_w'] = layer.weight.cpu().detach().numpy().tolist()
|
| 210 |
weights[name + '_b'] = layer.bias.cpu().detach().numpy().tolist()
|
| 211 |
|
| 212 |
+
is_correct = (prediction == enemy_label)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
+
# 4. 結果をクライアントに返す
|
| 215 |
return {
|
| 216 |
+
"prediction": prediction,
|
| 217 |
+
"label": enemy_label,
|
| 218 |
+
"is_correct": is_correct,
|
| 219 |
"confidence": confidence,
|
| 220 |
+
"image_b64": enemy_image_b64, # 受け取った画像をそのまま返す
|
| 221 |
"architecture": [{"type": "Input", "name": "input", "shape": [1, 28, 28]}] + model.architecture_info,
|
| 222 |
"outputs": intermediate_outputs,
|
| 223 |
"weights": weights
|
| 224 |
}
|
| 225 |
|
| 226 |
+
# --- FastAPI Endpoints ---
|
| 227 |
@app.get("/api/get_enemy")
|
| 228 |
async def get_enemy_endpoint():
|
| 229 |
return get_enemy()
|
| 230 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
@app.post("/api/run_inference")
|
| 232 |
+
async def run_inference_endpoint(payload: Dict[str, Any] = Body(...)):
|
| 233 |
+
"""
|
| 234 |
+
クライアントからモデル構成と敵データを受け取り、推論結果を返すエンドポイント。
|
| 235 |
+
"""
|
| 236 |
+
layer_configs = payload.get("layer_configs")
|
| 237 |
+
enemy_image_b64 = payload.get("enemy_image_b64")
|
| 238 |
+
enemy_label = payload.get("enemy_label")
|
| 239 |
+
|
| 240 |
+
if not all([layer_configs, enemy_image_b64, enemy_label is not None]):
|
| 241 |
+
return {"error": "リクエストのパラメータが不足しています。"}
|
| 242 |
+
|
| 243 |
+
return run_inference(layer_configs, enemy_image_b64, enemy_label)
|
| 244 |
|
| 245 |
+
# --- 静的ファイルの配信 ---
|
| 246 |
app.mount("/", StaticFiles(directory="web", html=True), name="static")
|
web/script.js
CHANGED
|
@@ -34,6 +34,7 @@ let isBattleInProgress = false; // ★★★ バトルループ中のフラグ
|
|
| 34 |
let draggedItem = null; // { type, layer, index }
|
| 35 |
let dragOverIndex = null; // 並び替え先のインデックス
|
| 36 |
let wasDroppedSuccessfully = false; // ★★★ このフラグを追加
|
|
|
|
| 37 |
let ENEMY_MAX_HP = 100;
|
| 38 |
const PLAYER_MAX_HP = 100;
|
| 39 |
|
|
@@ -285,14 +286,19 @@ function updateHpBars() {
|
|
| 285 |
}
|
| 286 |
|
| 287 |
async function fetchNewEnemy() {
|
| 288 |
-
// EelからFetch APIに変更
|
| 289 |
const response = await fetch('/api/get_enemy');
|
| 290 |
-
const
|
| 291 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
enemyMessage.textContent = '野生のMNISTモンスターが現れた!';
|
| 293 |
-
enemyImage.src =
|
| 294 |
enemyImage.classList.remove('hidden');
|
| 295 |
-
await animateBattleLog('', true);
|
| 296 |
}
|
| 297 |
|
| 298 |
// --- D&D Functions ---
|
|
@@ -891,7 +897,15 @@ async function handleBattle() {
|
|
| 891 |
await animateBattleLog('新たな敵をスキャン... 推論実行...');
|
| 892 |
|
| 893 |
// EelからFetch APIに変更
|
| 894 |
-
const inferenceResponse = await fetch('/api/run_inference', {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 895 |
const result = await inferenceResponse.json();
|
| 896 |
|
| 897 |
if (result.error) {
|
|
@@ -921,6 +935,7 @@ async function handleBattle() {
|
|
| 921 |
|
| 922 |
if (enemyHP > 0 && playerHP > 0) {
|
| 923 |
await sleep(1500);
|
|
|
|
| 924 |
}
|
| 925 |
}
|
| 926 |
|
|
|
|
| 34 |
let draggedItem = null; // { type, layer, index }
|
| 35 |
let dragOverIndex = null; // 並び替え先のインデックス
|
| 36 |
let wasDroppedSuccessfully = false; // ★★★ このフラグを追加
|
| 37 |
+
let currentEnemy = { image_b64: null, label: null }; // ★★★ クライアント側で敵の状態を保持
|
| 38 |
let ENEMY_MAX_HP = 100;
|
| 39 |
const PLAYER_MAX_HP = 100;
|
| 40 |
|
|
|
|
| 286 |
}
|
| 287 |
|
| 288 |
async function fetchNewEnemy() {
|
|
|
|
| 289 |
const response = await fetch('/api/get_enemy');
|
| 290 |
+
const enemyData = await response.json();
|
| 291 |
|
| 292 |
+
// ★★★ グローバル変数に保存
|
| 293 |
+
currentEnemy = {
|
| 294 |
+
image_b64: enemyData.image_b64,
|
| 295 |
+
label: enemyData.label
|
| 296 |
+
};
|
| 297 |
+
|
| 298 |
enemyMessage.textContent = '野生のMNISTモンスターが現れた!';
|
| 299 |
+
enemyImage.src = currentEnemy.image_b64;
|
| 300 |
enemyImage.classList.remove('hidden');
|
| 301 |
+
await animateBattleLog('', true);
|
| 302 |
}
|
| 303 |
|
| 304 |
// --- D&D Functions ---
|
|
|
|
| 897 |
await animateBattleLog('新たな敵をスキャン... 推論実行...');
|
| 898 |
|
| 899 |
// EelからFetch APIに変更
|
| 900 |
+
const inferenceResponse = await fetch('/api/run_inference', {
|
| 901 |
+
method: 'POST',
|
| 902 |
+
headers: { 'Content-Type': 'application/json' },
|
| 903 |
+
body: JSON.stringify({
|
| 904 |
+
layer_configs: playerLayers,
|
| 905 |
+
enemy_image_b64: currentEnemy.image_b64,
|
| 906 |
+
enemy_label: currentEnemy.label,
|
| 907 |
+
}),
|
| 908 |
+
});
|
| 909 |
const result = await inferenceResponse.json();
|
| 910 |
|
| 911 |
if (result.error) {
|
|
|
|
| 935 |
|
| 936 |
if (enemyHP > 0 && playerHP > 0) {
|
| 937 |
await sleep(1500);
|
| 938 |
+
await fetchNewEnemy(); // 次の敵を準備
|
| 939 |
}
|
| 940 |
}
|
| 941 |
|