# groxy/validators/dependency_resolver.py
# Author: Yasu777 — commit 00700bc ("Create dependency_resolver.py", verified)
import re
import ast
from typing import Dict, Any, List, Optional, Set, Tuple
from validators.base import BaseValidator, Validator
class DependencyResolver(BaseValidator, Validator):
    """Detect the library dependencies of Python code and insert missing imports.

    Analysis is static and AST-based: explicit imports are recorded, every name
    the code binds is collected, and an import is reported as missing when a
    known library name (``os``, ``np``, ``plt``, ``yaml``, ...) is read without
    being bound, or when call-pattern heuristics point at numpy, pandas,
    matplotlib or a physics-simulation library that is not imported.

    Scoping is approximated: a name bound anywhere in the code is treated as
    bound everywhere.
    """

    # Conventional aliases: alias -> dotted module path.
    _COMMON_ALIASES = {
        "np": "numpy",
        "pd": "pandas",
        "plt": "matplotlib.pyplot",
        "sns": "seaborn",
        "tf": "tensorflow",
    }

    # Bare calls (e.g. ``zeros(3)`` with ``zeros`` never bound) hinting at numpy.
    _NUMPY_CALLS = frozenset({
        "sin", "cos", "tan", "exp", "log", "sqrt", "array", "matrix", "zeros",
        "ones", "linspace", "arange", "reshape", "transpose",
    })
    # Bare calls hinting at pandas.
    _PANDAS_CALLS = frozenset({"DataFrame"})
    # Attribute names hinting at pandas. ``join`` is deliberately absent:
    # ``str.join`` would make almost every program look like pandas code.
    _PANDAS_ATTRIBUTES = frozenset({
        "iloc", "loc", "groupby", "pivot", "merge", "concat", "read_csv", "to_csv",
    })
    # Bare calls hinting at matplotlib.
    _MATPLOTLIB_CALLS = frozenset({
        "plot", "scatter", "figure", "subplot", "bar", "histogram", "legend",
        "title", "xlabel", "ylabel",
    })
    # Physics-simulation hints, checked in this order:
    # (common_libraries key, bare-call names, referenced names, bare-call name prefixes).
    _PHYSICS_SIGNATURES = (
        ("pymunk", frozenset({"Body", "Space", "Shape", "Circle", "Segment"}), frozenset(), ()),
        ("pybox2d", frozenset({"World", "Body", "Joint"}), frozenset({"Box2D"}), ()),
        ("vpython", frozenset({"sphere", "box", "cylinder", "vector", "rate"}), frozenset(), ()),
        ("mujoco", frozenset(), frozenset({"mjData", "mjModel"}), ("mj_",)),
        ("pybullet",
         frozenset({"createCylinder", "createBox", "getBasePositionAndOrientation"}),
         frozenset({"pybullet"}),
         ()),
        ("fenics",
         frozenset({"Mesh", "FunctionSpace", "Function", "split", "project"}),
         frozenset(),
         ()),
    )
    # Standard-library modules used for import grouping when
    # ``sys.stdlib_module_names`` (Python 3.10+) is unavailable.
    _FALLBACK_STDLIB_MODULES = frozenset({
        "os", "sys", "re", "math", "random", "datetime", "time", "json", "csv",
        "collections", "functools", "itertools", "pathlib", "typing", "logging",
        "argparse", "subprocess", "threading", "multiprocessing", "asyncio", "io",
        "unittest",
    })
    # PEP 263 source-encoding declaration (e.g. "# -*- coding: utf-8 -*-").
    _CODING_DECLARATION = re.compile(r"^[ \t\f]*#.*?coding[:=][ \t]*[-\w.]+")
    # Module path named by an import statement ("from a.b import c" -> "a.b").
    _IMPORT_MODULE = re.compile(r"^\s*(?:from|import)\s+([\w.]+)")
    # AST node types available on this Python version whose ``name``/``rest``
    # field binds a name: match-statement captures (3.10+) and PEP 695 type
    # parameters (3.12+).
    _CAPTURE_NODE_TYPES = tuple(
        getattr(ast, type_name)
        for type_name in ("MatchAs", "MatchStar", "MatchMapping", "TypeVar", "ParamSpec", "TypeVarTuple")
        if hasattr(ast, type_name)
    )

    def __init__(self, client=None):
        """Initialize the resolver.

        Args:
            client: forwarded unchanged to ``super().__init__``.
        """
        super().__init__(client)
        # Result of the most recent ``validate`` call (see ``analyze_dependencies``).
        self.validation_results = {}
        # Common libraries mapped to the import statement inserted for them.
        # Keys are module or distribution names; a few values are placeholders
        # ("from x import ...") that are not valid Python on their own.
        self.common_libraries = {
            # Python standard library
            "os": "import os",
            "sys": "import sys",
            "re": "import re",
            "math": "import math",
            "random": "import random",
            "datetime": "import datetime",
            "time": "import time",
            "json": "import json",
            "csv": "import csv",
            "collections": "import collections",
            "functools": "import functools",
            "itertools": "import itertools",
            "pathlib": "import pathlib",
            "typing": "import typing",
            "logging": "import logging",
            "argparse": "import argparse",
            "subprocess": "import subprocess",
            "threading": "import threading",
            "multiprocessing": "import multiprocessing",
            "asyncio": "import asyncio",
            "io": "import io",
            # Data processing and scientific computing
            "numpy": "import numpy as np",
            "pandas": "import pandas as pd",
            "scipy": "import scipy",
            "matplotlib": "import matplotlib.pyplot as plt",
            "seaborn": "import seaborn as sns",
            "sklearn": "from sklearn import ...",
            "tensorflow": "import tensorflow as tf",
            "torch": "import torch",
            "keras": "import keras",
            "sympy": "import sympy",
            # Web
            "requests": "import requests",
            "flask": "from flask import Flask",
            "django": "from django import ...",
            "fastapi": "from fastapi import FastAPI",
            "beautifulsoup4": "from bs4 import BeautifulSoup",
            "sqlalchemy": "from sqlalchemy import ...",
            "aiohttp": "import aiohttp",
            # Testing
            "pytest": "import pytest",
            "unittest": "import unittest",
            "mock": "from unittest import mock",
            # Utilities
            "tqdm": "from tqdm import tqdm",
            "pyyaml": "import yaml",
            "pillow": "from PIL import Image",
            "click": "import click",
            # Physics simulation
            "pymunk": "import pymunk",
            "pybox2d": "import Box2D",
            "mujoco": "import mujoco",
            "pybullet": "import pybullet",
            "opensim": "import opensim",
            "vpython": "import vpython",
            "fenics": "import fenics",
            "dolfinx": "import dolfinx",
            "pygmsh": "import pygmsh",
            "meshio": "import meshio",
            "pde": "import pde",
            "mpmath": "import mpmath",
            "dedalus": "import dedalus"
        }
        # Python built-in functions and types (never reported as missing).
        self.python_builtins = {
            "abs", "all", "any", "ascii", "bin", "bool", "breakpoint", "bytearray", "bytes",
            "callable", "chr", "classmethod", "compile", "complex", "delattr", "dict", "dir",
            "divmod", "enumerate", "eval", "exec", "filter", "float", "format", "frozenset",
            "getattr", "globals", "hasattr", "hash", "help", "hex", "id", "input", "int",
            "isinstance", "issubclass", "iter", "len", "list", "locals", "map", "max", "memoryview",
            "min", "next", "object", "oct", "open", "ord", "pow", "print", "property", "range",
            "repr", "reversed", "round", "set", "setattr", "slice", "sorted", "staticmethod",
            "str", "sum", "super", "tuple", "type", "vars", "zip", "__import__",
            # Exception types
            "BaseException", "Exception", "ArithmeticError", "BufferError", "LookupError",
            "AssertionError", "AttributeError", "EOFError", "FloatingPointError", "GeneratorExit",
            "ImportError", "ModuleNotFoundError", "IndexError", "KeyError", "KeyboardInterrupt",
            "MemoryError", "NameError", "NotImplementedError", "OSError", "OverflowError",
            "RecursionError", "ReferenceError", "RuntimeError", "StopIteration", "StopAsyncIteration",
            "SyntaxError", "IndentationError", "TabError", "SystemError", "SystemExit", "TypeError",
            "UnboundLocalError", "UnicodeError", "UnicodeEncodeError", "UnicodeDecodeError",
            "UnicodeTranslateError", "ValueError", "ZeroDivisionError"
        }

    async def validate(self, code, context=None):
        """Return ``code`` with missing imports inserted (interface implementation).

        The analysis result is stored in ``self.validation_results`` for
        ``get_result_summary``. ``context`` is accepted for interface
        compatibility and not used here.
        """
        analysis_result = self.analyze_dependencies(code)
        self.validation_results = analysis_result
        if analysis_result.get("missing_imports"):
            # Reuse the analysis instead of parsing the code a second time.
            return self.resolve_dependencies(code, analysis_result)
        return code

    def get_result_summary(self):
        """Return a (Japanese) summary of the last ``validate`` run (interface implementation)."""
        results = self.validation_results
        if not results:
            return "依存関係の分析はまだ実行されていません。"
        if results.get("error"):
            # Without this, a failed analysis was reported as "all imports present".
            return f"依存関係の分析に失敗しました: {results['error']}"
        detected_imports = results.get("detected_imports", [])
        missing_imports = results.get("missing_imports", [])
        implicit_dependencies = results.get("implicit_dependencies", [])
        summary = f"検出されたインポート: {len(detected_imports)}個\n"
        summary += f"暗黙的な依存関係: {len(implicit_dependencies)}個\n"
        if missing_imports:
            summary += "不足しているインポート:\n"
            summary += "\n".join(f"- {imp}" for imp in missing_imports[:5])
            if len(missing_imports) > 5:
                summary += f"\n...他{len(missing_imports) - 5}個"
        else:
            summary += "全ての依存関係が適切にインポートされています。"
        return summary

    def analyze_dependencies(self, code: str) -> Dict[str, Any]:
        """Analyze the dependencies of ``code``.

        Args:
            code: Python source text.

        Returns:
            A dict with:
              - "detected_imports": modules from ``import`` statements and
                "module.name" entries from ``from`` statements;
              - "missing_imports": import statements to add, without duplicates;
              - "unused_imports": not computed, always empty;
              - "implicit_dependencies": names/libraries that triggered a
                missing import, without duplicates;
              - "import_statements": the code's imports, one name per
                statement, without ``as`` aliases;
              - "error": present only when parsing or analysis failed.
        """
        result: Dict[str, Any] = {
            "detected_imports": [],
            "missing_imports": [],
            "unused_imports": [],
            "implicit_dependencies": [],
            "import_statements": []
        }
        try:
            tree = ast.parse(code)
            self._collect_imports(tree, result)
            usage = self._collect_usage(tree)
            # Known library names that are read but never bound need an import.
            unbound_names = usage["used"] - usage["bound"] - self.python_builtins
            name_index = self._build_name_index()
            for name in sorted(unbound_names):  # sorted: deterministic output order
                import_statement = self._import_statement_for_name(name, name_index)
                if import_statement is not None:
                    self._record_missing(result, name, import_statement)
            # Heuristics for libraries used through bare calls / attributes.
            self._detect_special_library_usage(code, result, usage)
            self._detect_physics_simulation_libraries(code, result, usage)
        except SyntaxError as e:
            print(f"[Error] Syntax error during dependency analysis: {str(e)}")
            result["error"] = f"構文エラー: {str(e)}"
        except Exception as e:
            print(f"[Error] Exception during dependency analysis: {str(e)}")
            result["error"] = f"解析エラー: {str(e)}"
        return result

    @staticmethod
    def _collect_imports(tree: ast.AST, result: Dict[str, Any]) -> None:
        """Append every import in ``tree`` to result["detected_imports"] and result["import_statements"]."""
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    result["detected_imports"].append(alias.name)
                    result["import_statements"].append(f"import {alias.name}")
            elif isinstance(node, ast.ImportFrom):
                # Relative imports carry a level and may have no module name
                # (``from . import x``); node.module alone would render as "None".
                module = "." * (node.level or 0) + (node.module or "")
                separator = "" if module.endswith(".") else "."
                for alias in node.names:
                    result["detected_imports"].append(f"{module}{separator}{alias.name}")
                    result["import_statements"].append(f"from {module} import {alias.name}")

    def _collect_usage(self, tree: ast.AST) -> Dict[str, Set[str]]:
        """Summarize how names are bound and used anywhere in ``tree``.

        Returns:
            A dict of name sets:
              - "bound": names the code binds (assignment/del/for/with/walrus
                targets, def and class names, parameters, import aliases,
                ``except ... as`` names, match captures, type parameters);
              - "used": names read through a plain ``Name`` load;
              - "called": names called directly, e.g. ``sin`` in ``sin(x)``;
              - "attributes": attribute names accessed, e.g. ``loc`` in ``df.loc``.
        """
        bound: Set[str] = set()
        used: Set[str] = set()
        called: Set[str] = set()
        attributes: Set[str] = set()
        for node in ast.walk(tree):
            if isinstance(node, ast.Name):
                if isinstance(node.ctx, ast.Load):
                    used.add(node.id)
                else:
                    bound.add(node.id)
            elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                bound.add(node.name)
            elif isinstance(node, ast.arg):
                bound.add(node.arg)
            elif isinstance(node, ast.alias):
                # ``import a.b`` binds "a"; ``import a.b as c`` / ``from a import b as c`` bind "c".
                bound.add(node.asname or node.name.split(".")[0])
            elif isinstance(node, ast.ExceptHandler):
                if node.name:
                    bound.add(node.name)
            elif isinstance(node, ast.Attribute):
                attributes.add(node.attr)
            elif isinstance(node, ast.Call):
                if isinstance(node.func, ast.Name):
                    called.add(node.func.id)
            elif isinstance(node, self._CAPTURE_NODE_TYPES):
                captured = getattr(node, "name", None) or getattr(node, "rest", None)
                if captured:
                    bound.add(captured)
        return {"bound": bound, "used": used, "called": called, "attributes": attributes}

    def _names_bound_by(self, import_statement: str) -> Set[str]:
        """Return the names an import statement binds (empty if it does not parse)."""
        try:
            tree = ast.parse(import_statement)
        except SyntaxError:
            # Placeholders such as "from sklearn import ..." are not valid Python.
            return set()
        names: Set[str] = set()
        for node in tree.body:
            if isinstance(node, (ast.Import, ast.ImportFrom)):
                for alias in node.names:
                    if alias.name != "*":
                        names.add(alias.asname or alias.name.split(".")[0])
        return names

    def _build_name_index(self) -> Dict[str, str]:
        """Map every name bound by a ``common_libraries`` statement to that statement.

        E.g. "np" -> "import numpy as np", "Image" -> "from PIL import Image".
        """
        index: Dict[str, str] = {}
        for import_statement in self.common_libraries.values():
            for name in sorted(self._names_bound_by(import_statement)):
                index.setdefault(name, import_statement)
        return index

    def _import_statement_for_name(self, name: str, name_index: Dict[str, str]) -> Optional[str]:
        """Return an import statement that binds ``name``, or None if it is not a known library name."""
        if name in name_index:
            return name_index[name]
        import_statement = self.common_libraries.get(name)
        if import_statement is not None and self._module_of(import_statement).split(".")[0] == name:
            # The canonical statement binds another name (e.g. "numpy" ->
            # "import numpy as np" binds only "np"), or is a placeholder
            # ("from sklearn import ..."); importing the module under its own
            # name is what makes ``name.<attr>`` resolve.
            return f"import {name}"
        return None

    def _module_of(self, import_statement: str) -> str:
        """Return the module path an import statement names ('' if unrecognized)."""
        match = self._IMPORT_MODULE.match(import_statement)
        return match.group(1) if match else ""

    def _is_already_imported(self, import_statement: str, result: Dict[str, Any]) -> bool:
        """Return True if the top-level package of ``import_statement`` is already imported by the code."""
        top_level = self._module_of(import_statement).split(".")[0]
        imported = {entry.split(".")[0] for entry in result["detected_imports"]}
        return bool(top_level) and top_level in imported

    @staticmethod
    def _record_missing(result: Dict[str, Any], dependency: str, import_statement: str) -> None:
        """Record a missing import and the dependency that triggered it, skipping duplicates."""
        if dependency not in result["implicit_dependencies"]:
            result["implicit_dependencies"].append(dependency)
        if import_statement not in result["missing_imports"]:
            result["missing_imports"].append(import_statement)

    def _add_library_if_missing(self, result: Dict[str, Any], dependency: str, import_statement: str) -> None:
        """Record ``import_statement`` as missing unless its package is already imported."""
        if not self._is_already_imported(import_statement, result):
            self._record_missing(result, dependency, import_statement)

    def _detect_special_library_usage(
        self,
        code: str,
        result: Dict[str, Any],
        usage: Optional[Dict[str, Set[str]]] = None,
    ) -> None:
        """Detect numpy / pandas / matplotlib usage and record their missing imports.

        Only bare calls of names the code never binds count: ``math.sin(x)``,
        ``ax.plot()`` or a locally defined ``plot`` do not. Text in strings
        and comments is ignored because the check runs on the AST.

        Args:
            code: source text; parsed only when ``usage`` is not supplied
                (SyntaxError propagates in that case).
            result: analysis dict updated in place.
            usage: output of ``_collect_usage`` for ``code``, if available.
        """
        if usage is None:
            usage = self._collect_usage(ast.parse(code))
        unbound_calls = usage["called"] - usage["bound"] - self.python_builtins
        if unbound_calls & self._NUMPY_CALLS:
            self._add_library_if_missing(result, "numpy", "import numpy as np")
        if unbound_calls & self._PANDAS_CALLS or usage["attributes"] & self._PANDAS_ATTRIBUTES:
            self._add_library_if_missing(result, "pandas", "import pandas as pd")
        if unbound_calls & self._MATPLOTLIB_CALLS:
            self._add_library_if_missing(result, "matplotlib", "import matplotlib.pyplot as plt")

    def _detect_physics_simulation_libraries(
        self,
        code: str,
        result: Dict[str, Any],
        usage: Optional[Dict[str, Set[str]]] = None,
    ) -> None:
        """Detect physics-simulation library usage and record missing imports.

        A library matches when the code calls one of its signature names
        without binding it, reads one of its signature names unbound, or calls
        an unbound name with one of its prefixes (``mj_...``).

        Args:
            code: source text; parsed only when ``usage`` is not supplied
                (SyntaxError propagates in that case).
            result: analysis dict updated in place.
            usage: output of ``_collect_usage`` for ``code``, if available.
        """
        if usage is None:
            usage = self._collect_usage(ast.parse(code))
        unbound_calls = usage["called"] - usage["bound"] - self.python_builtins
        unbound_names = usage["used"] - usage["bound"] - self.python_builtins
        for lib_name, call_names, referenced_names, call_prefixes in self._PHYSICS_SIGNATURES:
            uses_library = (
                bool(unbound_calls & call_names)
                or bool(unbound_names & referenced_names)
                or any(call.startswith(call_prefixes) for call in unbound_calls)
            )
            if uses_library:
                import_statement = self.common_libraries.get(lib_name, f"import {lib_name}")
                self._add_library_if_missing(result, lib_name, import_statement)

    def _is_common_alias(self, alias: str, library: str) -> bool:
        """Return True if ``alias`` is the conventional alias of ``library``.

        ``library`` may be the full module path ("matplotlib.pyplot") or its
        top-level package ("matplotlib"). Not used by the analysis itself.
        """
        target = self._COMMON_ALIASES.get(alias)
        return target is not None and library in (target, target.split(".")[0])

    def resolve_dependencies(self, code: str, analysis: Optional[Dict[str, Any]] = None) -> str:
        """Return ``code`` with the imports reported missing inserted near the top.

        Args:
            code: Python source text.
            analysis: result of ``analyze_dependencies(code)`` if already
                computed; the code is analyzed here when omitted.

        Returns:
            The updated source, or ``code`` unchanged when the analysis failed
            or nothing is missing.
        """
        if analysis is None:
            analysis = self.analyze_dependencies(code)
        if "error" in analysis:
            print(f"[Warning] Dependency analysis error: {analysis['error']}")
            return code
        # dict.fromkeys de-duplicates while keeping a deterministic order
        # (list(set(...)) reordered the statements from run to run).
        missing_imports = list(dict.fromkeys(analysis.get("missing_imports", [])))
        if not missing_imports:
            return code
        imports_block = self._format_import_block(missing_imports)
        return self._insert_import_block(code, imports_block)

    def _format_import_block(self, import_statements: List[str]) -> str:
        """Group statements into standard library, then everything else (PEP 8 order), sorted by module."""
        import sys

        stdlib_modules = getattr(sys, "stdlib_module_names", self._FALLBACK_STDLIB_MODULES)
        standard: List[str] = []
        third_party: List[str] = []
        for statement in import_statements:
            # Compare whole module names: a prefix test made "import re" match "import requests".
            top_level = self._module_of(statement).split(".")[0]
            (standard if top_level in stdlib_modules else third_party).append(statement)
        groups = [
            sorted(group, key=lambda statement: (self._module_of(statement).lower(), statement))
            for group in (standard, third_party)
            if group
        ]
        return "\n\n".join("\n".join(group) for group in groups)

    def _import_insertion_index(self, code: str, lines: List[str]) -> int:
        """Return how many leading lines of ``code`` must stay above inserted imports.

        These are the module docstring plus the leading run of top-level
        import statements (including ``from __future__`` imports). When neither
        exists, a shebang line and a PEP 263 encoding declaration are kept on top.
        """
        try:
            body = ast.parse(code).body
        except SyntaxError:
            body = []
        insert_at = 0
        position = 0
        first = body[0] if body else None
        if (isinstance(first, ast.Expr) and isinstance(first.value, ast.Constant)
                and isinstance(first.value.value, str)):
            insert_at = first.end_lineno
            position = 1
        while position < len(body) and isinstance(body[position], (ast.Import, ast.ImportFrom)):
            insert_at = body[position].end_lineno
            position += 1
        if insert_at == 0:
            if lines and lines[0].startswith("#!"):
                insert_at = 1
            # The encoding declaration is only honored on line 1 or 2.
            if insert_at < min(2, len(lines)) and self._CODING_DECLARATION.match(lines[insert_at]):
                insert_at += 1
        return insert_at

    def _insert_import_block(self, code: str, imports_block: str) -> str:
        """Insert ``imports_block`` below the module header, separated by blank lines."""
        newline = "\r\n" if "\r\n" in code else "\n"
        # Split on the newline sequence only, so indices agree with ast line numbers
        # (str.splitlines would also split on form feeds and other separators).
        lines = code.split(newline)
        insert_at = self._import_insertion_index(code, lines)
        head, tail = lines[:insert_at], lines[insert_at:]
        output = list(head)
        if head:
            output.append("")
        output.extend(imports_block.split("\n"))
        if tail and tail[0].strip():
            output.append("")
        output.extend(tail)
        return newline.join(output)