Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, Body | |
| from fastapi.staticfiles import StaticFiles | |
| from typing import List, Dict, Any | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torchvision import transforms | |
| from torch.utils.data import DataLoader | |
| import numpy as np | |
| import base64 | |
| from io import BytesIO | |
| from PIL import Image | |
| import random | |
| from datasets import load_dataset | |
# Create the FastAPI application instance.
app = FastAPI()
# --- Dynamic player model ---
class PlayerModel(nn.Module):
    """Sequential 10-class classifier assembled from a list of layer configs.

    The constructor assumes a single-channel 28x28 input and tracks the
    tensor shape while it walks ``layer_configs`` so that every ``Linear``
    layer gets the right ``in_features``.

    Each config is a dict with a ``'type'`` key and, where needed, a
    ``'params'`` dict:

    * ``Conv2d``: ``out_channels`` and ``kernel_size``; padding is
      ``kernel_size // 2``.
    * ``MaxPool2d`` / ``AvgPool2d``: ``kernel_size`` (also used as stride).
    * ``Dropout``: ``p``.
    * ``Linear``: ``out_features``.
    * ``ReLU``, ``Flatten``, ``ResidualBlock``: no parameters.
      ``ResidualBlock`` is built as a square ``nn.Linear``; no skip
      connection is added.

    Configs with any other type are skipped. A ``Flatten`` is inserted
    automatically before the first ``Linear``/``ResidualBlock`` when the
    tensor is still spatial, and a final ``Linear`` with 10 outputs is
    appended unless the last layer already is one.

    Attributes:
        layers: ``nn.ModuleList`` applied in order by ``forward``.
        architecture_info: one dict per recorded layer with ``type``,
            ``name`` and (except for ReLU/Dropout) ``shape``.
        hookable_layers: maps layer name to module for the layers whose
            outputs can be captured (conv, pooling, explicit flatten,
            linear, residual block, output layer).

    Raises:
        ValueError: if a Conv2d/pooling layer follows a flatten, or a
            pooling kernel is larger than the current feature map.
        KeyError: if a config lacks a required key.
    """

    def __init__(self, layer_configs):
        super().__init__()
        self.layers = nn.ModuleList()
        self.architecture_info = []
        self.hookable_layers = {}

        in_channels = 1        # channel count, or feature count once flattened
        feature_map_size = 28  # side length of the square feature map
        is_flattened = False

        for i, config in enumerate(layer_configs):
            layer_type = config['type']
            # Names are "<type>_<k>", k being how many layers of this type
            # have been recorded so far.
            same_type_count = sum(
                1 for info in self.architecture_info if info['type'] == layer_type
            )
            name = f"{layer_type.lower()}_{same_type_count}"

            if layer_type in ('Conv2d', 'MaxPool2d', 'AvgPool2d'):
                if is_flattened:
                    # A 2-D layer cannot consume the (batch, features) tensor
                    # produced by Flatten/Linear; fail at build time with a
                    # clear message instead of deep inside the forward pass.
                    raise ValueError(
                        f"{layer_type} (layer {i}) cannot follow a flattened output"
                    )
                kernel_size = config['params']['kernel_size']
                if layer_type == 'Conv2d':
                    out_channels = config['params']['out_channels']
                    padding = kernel_size // 2
                    layer = nn.Conv2d(
                        in_channels, out_channels,
                        kernel_size=kernel_size, padding=padding,
                    )
                    in_channels = out_channels
                    # Stride-1 convolution: odd kernels keep the size, even
                    # kernels grow it by one because of the symmetric padding.
                    feature_map_size = feature_map_size + 2 * padding - kernel_size + 1
                else:
                    if kernel_size > feature_map_size:
                        raise ValueError(
                            f"{layer_type} (layer {i}) kernel_size {kernel_size} "
                            f"exceeds the {feature_map_size}x{feature_map_size} feature map"
                        )
                    pool_cls = nn.MaxPool2d if layer_type == 'MaxPool2d' else nn.AvgPool2d
                    layer = pool_cls(kernel_size=kernel_size, stride=kernel_size)
                    feature_map_size //= kernel_size
                self.layers.append(layer)
                self.hookable_layers[name] = layer
                self.architecture_info.append({
                    "type": layer_type,
                    "name": name,
                    "shape": [in_channels, feature_map_size, feature_map_size],
                })

            elif layer_type in ('ReLU', 'Dropout'):
                if layer_type == 'ReLU':
                    self.layers.append(nn.ReLU())
                else:
                    self.layers.append(nn.Dropout(p=config['params']['p']))
                # Shape-preserving layers: recorded without a shape and not hookable.
                self.architecture_info.append({"type": layer_type, "name": name})

            elif layer_type == 'Flatten':
                # A repeated Flatten is a no-op and is dropped entirely.
                if not is_flattened:
                    layer = nn.Flatten()
                    self.layers.append(layer)
                    self.hookable_layers[name] = layer
                    flat_features = in_channels * feature_map_size * feature_map_size
                    in_channels = flat_features
                    self.architecture_info.append(
                        {"type": "Flatten", "name": name, "shape": [flat_features]}
                    )
                    is_flattened = True

            elif layer_type in ('Linear', 'ResidualBlock'):
                if not is_flattened:
                    # Implicit flatten: recorded in architecture_info but not hookable.
                    self.layers.append(nn.Flatten())
                    flat_features = in_channels * feature_map_size * feature_map_size
                    in_channels = flat_features
                    self.architecture_info.append({
                        "type": "Flatten",
                        "name": f"auto_flatten_{i}",
                        "shape": [flat_features],
                    })
                    is_flattened = True
                if layer_type == 'Linear':
                    out_features = config['params']['out_features']
                    layer = nn.Linear(in_channels, out_features)
                    in_channels = out_features
                else:
                    layer = nn.Linear(in_channels, in_channels)
                self.layers.append(layer)
                self.hookable_layers[name] = layer
                self.architecture_info.append(
                    {"type": layer_type, "name": name, "shape": [in_channels]}
                )
            # Any other layer type is ignored.

        # Guarantee that the network ends in a 10-way linear classifier.
        ends_in_classifier = (
            len(self.layers) > 0
            and isinstance(self.layers[-1], nn.Linear)
            and self.layers[-1].out_features == 10
        )
        if not ends_in_classifier:
            if not is_flattened:
                self.layers.append(nn.Flatten())
                final_in_features = in_channels * feature_map_size * feature_map_size
            else:
                final_in_features = in_channels
            output_layer = nn.Linear(final_in_features, 10)
            self.layers.append(output_layer)
            self.hookable_layers["linear_output"] = output_layer
            self.architecture_info.append(
                {"type": "Linear", "name": "linear_output", "shape": [10]}
            )

    def forward(self, x):
        """Apply every layer in order and return the final output."""
        for layer in self.layers:
            x = layer(x)
        return x
# --- Globals and data preparation (stateless design) ---
# Everything below is initialised once when the server starts and is then
# treated as a constant: requests do not modify it.
device = torch.device("cpu")
mnist_dataset = load_dataset("mnist")
transform = transforms.Compose([transforms.ToTensor()])


def apply_transforms(examples):
    """Replace each image in ``examples['image']`` with its grayscale tensor."""
    grayscale_tensors = []
    for image in examples['image']:
        grayscale_tensors.append(transform(image.convert("L")))
    examples['image'] = grayscale_tensors
    return examples


mnist_dataset.set_transform(apply_transforms)

# Training data: the first 1000 training examples, served in shuffled batches.
train_subset = mnist_dataset['train'].select(range(1000))
train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)

# Enemy pool: 1000 shuffled test examples, kept in memory as
# (image with a leading batch axis, label tensor) pairs.
test_subset_for_inference = mnist_dataset['test'].shuffle().select(range(1000))
test_images = [
    (item['image'].unsqueeze(0), torch.tensor(item['label']))
    for item in test_subset_for_inference
]
# --- Backend logic (stateless functions) ---
def get_enemy():
    """Return a new enemy: a base64 PNG data URL and its correct label.

    The enemy is drawn at random from ``test_images``; no state is kept on
    the server side.
    """
    enemy_tensor, enemy_label = random.choice(test_images)

    # Drop the batch axis, render to PNG in memory, then base64-encode it.
    png_buffer = BytesIO()
    to_pil = transforms.ToPILImage()
    to_pil(enemy_tensor.squeeze(0)).save(png_buffer, format="PNG")
    encoded_png = base64.b64encode(png_buffer.getvalue()).decode("utf-8")

    enemy = {}
    enemy["image_b64"] = "data:image/png;base64," + encoded_png
    enemy["label"] = enemy_label.item()
    return enemy
def _decode_enemy_image(enemy_image_b64):
    """Decode a base64 image (data URL or bare payload) into a batched tensor."""
    # The base64 alphabet has no comma, so anything before the first comma
    # is a data-URL header; without a comma the whole string is the payload.
    _header, separator, payload = enemy_image_b64.partition(",")
    encoded = payload if separator else enemy_image_b64
    image_pil = Image.open(BytesIO(base64.b64decode(encoded))).convert("L")
    return transforms.ToTensor()(image_pil).unsqueeze(0).to(device)


def _train_player_model(layer_configs, epochs=3, learning_rate=0.001):
    """Build a PlayerModel from ``layer_configs`` and fit it on ``train_loader``."""
    model = PlayerModel(layer_configs).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    for _ in range(epochs):
        for batch in train_loader:
            data, target = batch['image'].to(device), batch['label'].to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(data), target)
            loss.backward()
            optimizer.step()
    return model


def _forward_with_captures(model, image_tensor):
    """Run one forward pass, recording the output of every hookable layer."""
    captured = {}

    def make_hook(layer_name):
        def hook(module, inputs, output):
            captured[layer_name] = output.detach().cpu().numpy().tolist()
        return hook

    handles = [
        layer.register_forward_hook(make_hook(layer_name))
        for layer_name, layer in model.hookable_layers.items()
    ]
    try:
        with torch.no_grad():
            logits = model(image_tensor)
    finally:
        # Always detach the hooks, even when the forward pass raises.
        for handle in handles:
            handle.remove()
    return logits, captured


def run_inference(layer_configs: list, enemy_image_b64: str, enemy_label: int):
    """Build and train a model for this request, then classify the enemy image.

    Nothing is kept on the server between calls: the model is constructed
    from ``layer_configs``, trained for three epochs on ``train_loader`` and
    discarded once the result has been assembled.

    Args:
        layer_configs: layer configuration dicts accepted by ``PlayerModel``.
        enemy_image_b64: the enemy image as a base64 data URL; a bare
            base64 payload without the ``data:...,`` header is accepted too.
        enemy_label: the correct label for the enemy image.

    Returns:
        On success a dict with ``prediction``, ``label``, ``is_correct``,
        ``confidence``, ``image_b64``, ``architecture``, ``outputs``
        (per-layer activations plus ``'input'``) and ``weights``.
        On failure a dict with a single ``error`` message.
    """
    if not layer_configs:
        return {"error": "モデルが空です。"}

    # 1. Decode the enemy image first: it is cheap, and a bad image should
    #    not cost a full training run before being rejected.
    try:
        image_tensor = _decode_enemy_image(enemy_image_b64)
    except Exception as e:  # reported to the client rather than raised
        print(f"Error decoding enemy image: {e}")
        return {"error": f"敵画像のデコードエラー: {e}"}

    # 2. Build and train the model on the fly.
    try:
        model = _train_player_model(layer_configs)
        print("On-the-fly training for inference finished.")
    except Exception as e:  # invalid configs surface here; report to the client
        print(f"Error during on-the-fly training: {e}")
        return {"error": f"推論中のモデル構築・訓練エラー: {e}"}

    # 3. Run inference and capture the intermediate outputs.
    model.eval()
    try:
        output, intermediate_outputs = _forward_with_captures(model, image_tensor)
    except Exception as e:  # e.g. an image whose size the model cannot consume
        print(f"Error during inference: {e}")
        return {"error": f"推論の実行エラー: {e}"}

    probabilities = torch.nn.functional.softmax(output, dim=1)
    prediction = torch.argmax(probabilities, dim=1).item()
    confidence = probabilities[0, prediction].item()
    intermediate_outputs['input'] = image_tensor.cpu().numpy().tolist()

    # Only layers with learnable parameters contribute weights.
    weights = {}
    for name, layer in model.hookable_layers.items():
        if isinstance(layer, (nn.Linear, nn.Conv2d)):
            weights[name + '_w'] = layer.weight.detach().cpu().numpy().tolist()
            if layer.bias is not None:
                weights[name + '_b'] = layer.bias.detach().cpu().numpy().tolist()

    # 4. Return the result to the client.
    return {
        "prediction": prediction,
        "label": enemy_label,
        "is_correct": prediction == enemy_label,
        "confidence": confidence,
        "image_b64": enemy_image_b64,  # echo the received image unchanged
        "architecture": [{"type": "Input", "name": "input", "shape": [1, 28, 28]}] + model.architecture_info,
        "outputs": intermediate_outputs,
        "weights": weights,
    }
# --- FastAPI endpoints ---
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm that it is registered on `app`.
async def get_enemy_endpoint():
    """Return whatever ``get_enemy`` produces for a new enemy."""
    enemy = get_enemy()
    return enemy
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm that it is registered on `app`.
async def run_inference_endpoint(payload: Dict[str, Any] = Body(...)):
    """Receive a model configuration and enemy data and return the inference result.

    Args:
        payload: JSON body with ``layer_configs``, ``enemy_image_b64`` and
            ``enemy_label``.

    Returns:
        The dict produced by ``run_inference``, or ``{"error": ...}`` when a
        required field is absent.
    """
    layer_configs = payload.get("layer_configs")
    enemy_image_b64 = payload.get("enemy_image_b64")
    enemy_label = payload.get("enemy_label")

    # Only an absent model counts as "missing". An empty list is passed on so
    # that run_inference can answer with its dedicated empty-model error.
    # A label of 0 is valid, hence the explicit None check.
    if layer_configs is None or not enemy_image_b64 or enemy_label is None:
        return {"error": "リクエストのパラメータが不足しています。"}

    # NOTE(review): run_inference is synchronous and trains a model; called
    # directly from an async handler it presumably blocks the event loop for
    # the whole run — consider a plain `def` handler or a thread pool.
    return run_inference(layer_configs, enemy_image_b64, enemy_label)
# --- Static file serving ---
app.mount(
    "/",
    StaticFiles(directory="web", html=True),
    name="static",
)