File size: 17,782 Bytes
00700bc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 | import re
import ast
from typing import Dict, Any, List, Optional, Set, Tuple
from validators.base import BaseValidator, Validator
class DependencyResolver(BaseValidator, Validator):
    """Detect library dependencies of Python code and insert missing imports.

    Two detection strategies are combined:

    * an AST pass that flags names which are read but never bound anywhere in
      the code and that equal a known library name or a conventional alias
      (``np``, ``pd``, ``plt``, ``sns``, ``tf``);
    * heuristics that flag characteristic *unqualified* calls (``sin(...)``,
      ``plot(...)``, ``Body(...)``) or pandas-style attribute accesses
      (``df.iloc``) when the matching library is not imported in any form.

    ``validate`` stores the latest analysis in ``self.validation_results`` and
    returns the code with the missing import statements inserted.
    """

    # Unqualified calls that suggest NumPy usage. Qualified calls such as
    # ``math.sin(x)`` or ``np.sin(x)`` are intentionally not matched here.
    _NUMPY_BARE_CALLS = frozenset({
        "sin", "cos", "tan", "exp", "log", "sqrt", "array", "matrix", "zeros",
        "ones", "linspace", "arange", "reshape", "transpose",
    })
    # Unqualified calls / attribute names that suggest pandas usage. "join" is
    # deliberately excluded: ",".join(...) and os.path.join(...) are far more
    # common than DataFrame.join and would add spurious pandas imports.
    _PANDAS_BARE_CALLS = frozenset({"DataFrame"})
    _PANDAS_ATTRIBUTES = frozenset({
        "iloc", "loc", "groupby", "pivot", "merge", "concat", "read_csv", "to_csv",
    })
    # Unqualified (pylab-style) plotting calls that suggest Matplotlib usage.
    _MATPLOTLIB_BARE_CALLS = frozenset({
        "plot", "scatter", "figure", "subplot", "bar", "histogram", "legend",
        "title", "xlabel", "ylabel",
    })
    # Physics-simulation library -> (unqualified call names, call-name prefixes,
    # referenced names). A library is flagged when any of the three matches.
    _PHYSICS_SIGNATURES = {
        "pymunk": (frozenset({"Body", "Space", "Shape", "Circle", "Segment"}), (), frozenset()),
        "pybox2d": (frozenset({"World", "Body", "Joint"}), (), frozenset({"Box2D"})),
        "vpython": (frozenset({"sphere", "box", "cylinder", "vector", "rate"}), (), frozenset()),
        "mujoco": (frozenset(), ("mj_",), frozenset({"mjData", "mjModel"})),
        "pybullet": (
            frozenset({"createCylinder", "createBox", "getBasePositionAndOrientation"}),
            (),
            frozenset({"pybullet"}),
        ),
        "fenics": (frozenset({"Mesh", "FunctionSpace", "Function", "split", "project"}), (), frozenset()),
    }
    # Standard-library modules used to group inserted imports when
    # sys.stdlib_module_names (Python 3.10+) is not available.
    _STDLIB_FALLBACK = frozenset({
        "__future__", "os", "sys", "re", "math", "random", "datetime", "time", "json",
        "csv", "collections", "functools", "itertools", "pathlib", "typing", "logging",
        "argparse", "subprocess", "threading", "multiprocessing", "asyncio", "io",
        "unittest",
    })
    # PEP 263 source-encoding declaration; it must remain on line 1 or 2.
    _CODING_RE = re.compile(r"^[ \t\f]*#.*?coding[:=][ \t]*[-\w.]+")

    def __init__(self, client=None):
        """Initialise the resolver.

        Args:
            client: Passed through unchanged to the base-class constructor.
        """
        super().__init__(client)
        # Result of the most recent analyze_dependencies() call made by validate().
        self.validation_results = {}
        # Known libraries mapped to the import statement inserted when they are
        # missing. Every value must be a complete, valid import statement because
        # it is inserted into user code verbatim.
        self.common_libraries = {
            # Python standard library
            "os": "import os",
            "sys": "import sys",
            "re": "import re",
            "math": "import math",
            "random": "import random",
            "datetime": "import datetime",
            "time": "import time",
            "json": "import json",
            "csv": "import csv",
            "collections": "import collections",
            "functools": "import functools",
            "itertools": "import itertools",
            "pathlib": "import pathlib",
            "typing": "import typing",
            "logging": "import logging",
            "argparse": "import argparse",
            "subprocess": "import subprocess",
            "threading": "import threading",
            "multiprocessing": "import multiprocessing",
            "asyncio": "import asyncio",
            "io": "import io",
            # Data processing and scientific computing
            "numpy": "import numpy as np",
            "pandas": "import pandas as pd",
            "scipy": "import scipy",
            "matplotlib": "import matplotlib.pyplot as plt",
            "seaborn": "import seaborn as sns",
            "sklearn": "import sklearn",
            "tensorflow": "import tensorflow as tf",
            "torch": "import torch",
            "keras": "import keras",
            "sympy": "import sympy",
            # Web
            "requests": "import requests",
            "flask": "from flask import Flask",
            "django": "import django",
            "fastapi": "from fastapi import FastAPI",
            "beautifulsoup4": "from bs4 import BeautifulSoup",
            "sqlalchemy": "import sqlalchemy",
            "aiohttp": "import aiohttp",
            # Testing
            "pytest": "import pytest",
            "unittest": "import unittest",
            "mock": "from unittest import mock",
            # Utilities
            "tqdm": "from tqdm import tqdm",
            "pyyaml": "import yaml",
            "pillow": "from PIL import Image",
            "click": "import click",
            # Physics simulation
            "pymunk": "import pymunk",
            "pybox2d": "import Box2D",
            "mujoco": "import mujoco",
            "pybullet": "import pybullet",
            "opensim": "import opensim",
            "vpython": "import vpython",
            "fenics": "import fenics",
            "dolfinx": "import dolfinx",
            "pygmsh": "import pygmsh",
            "meshio": "import meshio",
            "pde": "import pde",
            "mpmath": "import mpmath",
            "dedalus": "import dedalus",
        }
        # Python built-in functions, types and exception classes; these are never
        # reported as missing imports.
        self.python_builtins = {
            "abs", "all", "any", "ascii", "bin", "bool", "breakpoint", "bytearray", "bytes",
            "callable", "chr", "classmethod", "compile", "complex", "delattr", "dict", "dir",
            "divmod", "enumerate", "eval", "exec", "filter", "float", "format", "frozenset",
            "getattr", "globals", "hasattr", "hash", "help", "hex", "id", "input", "int",
            "isinstance", "issubclass", "iter", "len", "list", "locals", "map", "max", "memoryview",
            "min", "next", "object", "oct", "open", "ord", "pow", "print", "property", "range",
            "repr", "reversed", "round", "set", "setattr", "slice", "sorted", "staticmethod",
            "str", "sum", "super", "tuple", "type", "vars", "zip", "__import__",
            # Exception types
            "BaseException", "Exception", "ArithmeticError", "BufferError", "LookupError",
            "AssertionError", "AttributeError", "EOFError", "FloatingPointError", "GeneratorExit",
            "ImportError", "ModuleNotFoundError", "IndexError", "KeyError", "KeyboardInterrupt",
            "MemoryError", "NameError", "NotImplementedError", "OSError", "OverflowError",
            "RecursionError", "ReferenceError", "RuntimeError", "StopIteration", "StopAsyncIteration",
            "SyntaxError", "IndentationError", "TabError", "SystemError", "SystemExit", "TypeError",
            "UnboundLocalError", "UnicodeError", "UnicodeEncodeError", "UnicodeDecodeError",
            "UnicodeTranslateError", "ValueError", "ZeroDivisionError",
        }

    async def validate(self, code, context=None):
        """Analyse *code* and return it with missing imports added (interface method).

        The analysis is stored in ``self.validation_results``. *context* is
        accepted for interface compatibility and is not used.
        """
        analysis_result = self.analyze_dependencies(code)
        self.validation_results = analysis_result
        if analysis_result.get("missing_imports"):
            # Reuse the analysis instead of parsing and analysing the code twice.
            return self.resolve_dependencies(code, analysis_result)
        return code

    def get_result_summary(self):
        """Return a Japanese, human-readable summary of the last analysis (interface method)."""
        results = self.validation_results
        if not results:
            return "依存関係の分析はまだ実行されていません。"
        detected_imports = results.get("detected_imports", [])
        missing_imports = results.get("missing_imports", [])
        implicit_dependencies = results.get("implicit_dependencies", [])
        summary = f"検出されたインポート: {len(detected_imports)}個\n"
        summary += f"暗黙的な依存関係: {len(implicit_dependencies)}個\n"
        if missing_imports:
            summary += "不足しているインポート:\n"
            summary += "\n".join(f"- {imp}" for imp in missing_imports[:5])
            if len(missing_imports) > 5:
                summary += f"\n...他{len(missing_imports) - 5}個"
        elif results.get("error"):
            # A failed analysis must not be reported as "all dependencies imported".
            summary += f"依存関係の分析に失敗しました: {results['error']}"
        else:
            summary += "全ての依存関係が適切にインポートされています。"
        return summary

    def analyze_dependencies(self, code: str) -> Dict[str, Any]:
        """Analyse *code* and report its imports and missing library imports.

        Returns:
            A dict with keys ``detected_imports`` (dotted names of imported
            modules/objects), ``missing_imports`` (import statements to add,
            without duplicates), ``unused_imports`` (currently never populated),
            ``implicit_dependencies`` (the name or library that triggered each
            entry of ``missing_imports``, in the same order) and
            ``import_statements``. An ``error`` key is added when parsing or
            analysis fails; exceptions are reported, not raised.
        """
        result: Dict[str, Any] = {
            "detected_imports": [],
            "missing_imports": [],
            "unused_imports": [],
            "implicit_dependencies": [],
            "import_statements": [],
        }
        try:
            tree = ast.parse(code)
            self._collect_imports(tree, result)
            # Names that are read but never bound anywhere in the code (assignment,
            # def/class, parameter, import alias, ...) and are not builtins.
            defined_names = self._collect_bound_names(tree)
            used_names = {
                node.id for node in ast.walk(tree)
                if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load)
            }
            unbound_names = (used_names - defined_names).difference(self.python_builtins)
            # sorted() keeps the order of inserted imports deterministic.
            for name in sorted(unbound_names):
                for lib_name, import_stmt in self.common_libraries.items():
                    if name == lib_name or self._is_common_alias(name, lib_name):
                        self._record_missing(result, name, self._statement_binding(name, import_stmt))
            # Heuristic detection of characteristic library usage patterns.
            self._detect_special_library_usage(code, result, tree)
            self._detect_physics_simulation_libraries(code, result, tree)
        except SyntaxError as e:
            print(f"[Error] Syntax error during dependency analysis: {str(e)}")
            result["error"] = f"構文エラー: {str(e)}"
        except Exception as e:
            print(f"[Error] Exception during dependency analysis: {str(e)}")
            result["error"] = f"解析エラー: {str(e)}"
        return result

    @staticmethod
    def _collect_imports(tree: ast.AST, result: Dict[str, Any]) -> None:
        """Record every import in *tree* into ``detected_imports`` and ``import_statements``."""
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    result["detected_imports"].append(alias.name)
                    statement = f"import {alias.name}"
                    if alias.asname:
                        statement += f" as {alias.asname}"
                    result["import_statements"].append(statement)
            elif isinstance(node, ast.ImportFrom):
                # Relative imports carry `level` leading dots and may have no module
                # (``from . import x``), which previously produced "None.x".
                module = "." * (node.level or 0) + (node.module or "")
                for alias in node.names:
                    if node.module:
                        result["detected_imports"].append(f"{module}.{alias.name}")
                    else:
                        result["detected_imports"].append(f"{module}{alias.name}")
                    statement = f"from {module} import {alias.name}"
                    if alias.asname:
                        statement += f" as {alias.asname}"
                    result["import_statements"].append(statement)

    @staticmethod
    def _collect_bound_names(tree: ast.AST) -> Set[str]:
        """Return every name bound anywhere in *tree*.

        Covers assignment/loop/with/comprehension targets, ``del`` targets,
        function and class definitions, parameters, import aliases and
        ``except ... as`` names.
        """
        bound: Set[str] = set()
        for node in ast.walk(tree):
            if isinstance(node, ast.Name) and isinstance(node.ctx, (ast.Store, ast.Del)):
                bound.add(node.id)
            elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                bound.add(node.name)
            elif isinstance(node, ast.arg):
                bound.add(node.arg)
            elif isinstance(node, ast.alias):
                # ``import a.b`` binds ``a``; ``from m import x as y`` binds ``y``.
                bound.add(node.asname or node.name.split(".")[0])
            elif isinstance(node, ast.ExceptHandler) and node.name:
                bound.add(node.name)
        return bound

    @staticmethod
    def _bare_called_names(tree: ast.AST) -> Set[str]:
        """Return the names called without qualification (``f(...)``, not ``x.f(...)``)."""
        return {
            node.func.id for node in ast.walk(tree)
            if isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
        }

    @staticmethod
    def _parse_or_none(code: str) -> Optional[ast.Module]:
        """Parse *code*, returning None instead of raising on invalid source."""
        try:
            return ast.parse(code)
        except (SyntaxError, ValueError):
            return None

    @staticmethod
    def _parse_import(statement: str) -> Tuple[Optional[str], Optional[str]]:
        """Return ``(top-level module, bound name)`` for a single import statement.

        ``"import numpy as np"`` -> ``("numpy", "np")``;
        ``"from PIL import Image"`` -> ``("PIL", "Image")``.
        Returns ``(None, None)`` for relative or unparseable statements.
        """
        try:
            node = ast.parse(statement).body[0]
        except (SyntaxError, ValueError, IndexError):
            return None, None
        if isinstance(node, ast.Import):
            alias = node.names[0]
            top = alias.name.split(".")[0]
            return top, alias.asname or top
        if isinstance(node, ast.ImportFrom) and node.module and not node.level:
            alias = node.names[0]
            return node.module.split(".")[0], alias.asname or alias.name
        return None, None

    def _statement_binding(self, name: str, import_stmt: str) -> str:
        """Return an import statement that actually binds *name*.

        The canonical statement may bind a different name (``import numpy as np``
        binds ``np``). When *name* is the module itself, a plain ``import name``
        is returned instead; otherwise *import_stmt* is returned unchanged.
        """
        module, bound = self._parse_import(import_stmt)
        if bound != name and module == name:
            return f"import {name}"
        return import_stmt

    @staticmethod
    def _imported_top_level_modules(result: Dict[str, Any]) -> Set[str]:
        """Return the top-level package of every detected import ("matplotlib.pyplot" -> "matplotlib")."""
        return {name.split(".")[0] for name in result.get("detected_imports", [])}

    @staticmethod
    def _record_missing(result: Dict[str, Any], dependency: str, statement: str) -> None:
        """Add *statement* to ``missing_imports`` and *dependency* to ``implicit_dependencies``, once."""
        if statement not in result["missing_imports"]:
            result["missing_imports"].append(statement)
            result["implicit_dependencies"].append(dependency)

    def _add_if_not_imported(
        self, result: Dict[str, Any], dependency: str, statement: str, imported_modules: Set[str]
    ) -> None:
        """Record *statement* unless its top-level module is already imported in any form."""
        module, _ = self._parse_import(statement)
        if (module or dependency) not in imported_modules:
            self._record_missing(result, dependency, statement)

    def _detect_special_library_usage(
        self, code: str, result: Dict[str, Any], tree: Optional[ast.AST] = None
    ) -> None:
        """Detect NumPy / pandas / Matplotlib usage from characteristic patterns.

        Only unqualified calls count for NumPy and Matplotlib, and names the code
        defines or imports itself are ignored. Matching runs on the AST, so text
        inside strings and comments cannot trigger a detection.

        Args:
            code: Source code; parsed only when *tree* is not supplied.
            result: Analysis dict updated in place.
            tree: Optional pre-parsed AST of *code*.
        """
        if tree is None:
            tree = self._parse_or_none(code)
            if tree is None:
                return
        defined = self._collect_bound_names(tree)
        called = self._bare_called_names(tree) - defined
        imported = self._imported_top_level_modules(result)
        # Attribute accesses whose receiver is an imported name are qualified
        # module calls (itertools.groupby, heapq.merge), not pandas object ops.
        import_aliases = {
            alias.asname or alias.name.split(".")[0]
            for node in ast.walk(tree)
            if isinstance(node, (ast.Import, ast.ImportFrom))
            for alias in node.names
        }
        attributes = {
            node.attr for node in ast.walk(tree)
            if isinstance(node, ast.Attribute)
            and not (isinstance(node.value, ast.Name) and node.value.id in import_aliases)
        }

        if called & self._NUMPY_BARE_CALLS:
            self._add_if_not_imported(result, "numpy", "import numpy as np", imported)
        if called & self._PANDAS_BARE_CALLS or attributes & self._PANDAS_ATTRIBUTES:
            self._add_if_not_imported(result, "pandas", "import pandas as pd", imported)
        if called & self._MATPLOTLIB_BARE_CALLS:
            self._add_if_not_imported(result, "matplotlib", "import matplotlib.pyplot as plt", imported)

    def _detect_physics_simulation_libraries(
        self, code: str, result: Dict[str, Any], tree: Optional[ast.AST] = None
    ) -> None:
        """Detect physics-simulation libraries from characteristic calls and names.

        Args:
            code: Source code; parsed only when *tree* is not supplied.
            result: Analysis dict updated in place.
            tree: Optional pre-parsed AST of *code*.
        """
        if tree is None:
            tree = self._parse_or_none(code)
            if tree is None:
                return
        defined = self._collect_bound_names(tree)
        called = self._bare_called_names(tree) - defined
        referenced = {
            node.id for node in ast.walk(tree)
            if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load)
        } - defined
        imported = self._imported_top_level_modules(result)

        for lib_name, (call_names, call_prefixes, ref_names) in self._PHYSICS_SIGNATURES.items():
            matched = (
                bool(called & call_names)
                or (bool(call_prefixes) and any(name.startswith(call_prefixes) for name in called))
                or bool(referenced & ref_names)
            )
            if matched:
                statement = self.common_libraries.get(lib_name, f"import {lib_name}")
                self._add_if_not_imported(result, lib_name, statement, imported)

    def _is_common_alias(self, alias: str, library: str) -> bool:
        """Return True if *alias* is the conventional alias for *library*.

        Both the full dotted target ("matplotlib.pyplot") and its top-level
        package ("matplotlib") are accepted, because ``common_libraries`` is
        keyed by top-level package names.
        """
        common_aliases = {
            "np": "numpy",
            "pd": "pandas",
            "plt": "matplotlib.pyplot",
            "sns": "seaborn",
            "tf": "tensorflow",
        }
        target = common_aliases.get(alias)
        if target is None:
            return False
        return library in (target, target.split(".")[0])

    def resolve_dependencies(self, code: str, analysis: Optional[Dict[str, Any]] = None) -> str:
        """Return *code* with the missing import statements inserted.

        Args:
            code: Python source to fix.
            analysis: Optional result of ``analyze_dependencies(code)``; computed
                when omitted.

        Returns:
            The updated source, or *code* unchanged when nothing is missing or the
            analysis failed. New imports go after the first run of top-level import
            statements; otherwise after the module docstring, shebang or encoding
            declaration; otherwise at the very top.
        """
        if analysis is None:
            analysis = self.analyze_dependencies(code)
        if "error" in analysis:
            print(f"[Warning] Dependency analysis error: {analysis['error']}")
            return code
        missing_imports = analysis.get("missing_imports")
        if not missing_imports:
            return code
        tree = self._parse_or_none(code)
        if tree is None:
            # Can only happen when *analysis* was computed for different code.
            return code
        # dict.fromkeys removes duplicates while keeping a deterministic order
        # (set() ordering of strings depends on hash randomisation).
        ordered = self._order_imports(list(dict.fromkeys(missing_imports)))
        return self._insert_imports(code, tree, "\n".join(ordered))

    def _order_imports(self, statements: List[str]) -> List[str]:
        """Order statements as standard library, then third-party, then anything else.

        Relative order is kept within each group. Classification uses the
        statement's top-level module rather than a string prefix, so
        ``import requests`` is not confused with ``import re``.
        """
        import sys  # local import: only needed for stdlib_module_names (Python 3.10+)

        stdlib = getattr(sys, "stdlib_module_names", self._STDLIB_FALLBACK)
        std_libs: List[str] = []
        third_party_libs: List[str] = []
        other_libs: List[str] = []
        for statement in statements:
            module, _ = self._parse_import(statement)
            if module is None:
                other_libs.append(statement)  # relative or unparseable statement
            elif module in stdlib:
                std_libs.append(statement)
            else:
                third_party_libs.append(statement)
        return std_libs + third_party_libs + other_libs

    def _insert_imports(self, code: str, tree: ast.Module, imports_block: str) -> str:
        """Insert *imports_block* into *code* at the conventional position."""
        body = tree.body
        import_nodes = (ast.Import, ast.ImportFrom)
        first = next((i for i, node in enumerate(body) if isinstance(node, import_nodes)), None)
        if first is not None:
            # Extend to the end of the first run of consecutive top-level imports.
            # Using the AST end line keeps parenthesised multi-line imports intact.
            last = first
            while last + 1 < len(body) and isinstance(body[last + 1], import_nodes):
                last += 1
            cut = self._offset_after_line(code, self._end_line(body[last]))
            head = code[:cut] if code[:cut].endswith("\n") else code[:cut] + "\n"
            return head + "\n" + imports_block + "\n" + code[cut:]

        header_end = self._header_end_line(code, body)
        if header_end:
            cut = self._offset_after_line(code, header_end)
            head = code[:cut] if code[:cut].endswith("\n") else code[:cut] + "\n"
            return head + "\n" + imports_block + "\n\n" + code[cut:]
        return imports_block + "\n\n" + code

    def _header_end_line(self, code: str, body: List[ast.stmt]) -> int:
        """Return the last line of the module header, or 0 if there is none.

        The module docstring wins when present, since any shebang or encoding
        line precedes it. Otherwise a shebang on line 1 and/or a PEP 263
        encoding declaration on line 1 or 2 form the header.
        """
        first = body[0] if body else None
        if (
            isinstance(first, ast.Expr)
            and isinstance(first.value, ast.Constant)
            and isinstance(first.value.value, str)
        ):
            return self._end_line(first)
        header_end = 0
        for index, line in enumerate(code.split("\n", 2)[:2]):
            if (index == 0 and line.startswith("#!")) or self._CODING_RE.match(line):
                header_end = index + 1
        return header_end

    @staticmethod
    def _end_line(node: ast.AST) -> int:
        """Return the last source line of *node*, falling back to its first line."""
        return getattr(node, "end_lineno", None) or node.lineno

    @staticmethod
    def _offset_after_line(code: str, lineno: int) -> int:
        """Return the character offset just past line *lineno* (1-based), newline included."""
        offset = 0
        for _ in range(lineno):
            newline = code.find("\n", offset)
            if newline == -1:
                return len(code)
            offset = newline + 1
        return offset