Yasu777 commited on
Commit
00700bc
·
verified ·
1 Parent(s): fd9f463

Create dependency_resolver.py

Browse files
Files changed (1) hide show
  1. validators/dependency_resolver.py +359 -0
validators/dependency_resolver.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import ast
3
+ from typing import Dict, Any, List, Optional, Set, Tuple
4
+
5
+ from validators.base import BaseValidator, Validator
6
+
7
+ class DependencyResolver(BaseValidator, Validator):
8
+ """ライブラリ依存性の検出と解決を行うクラス"""
9
+
10
+ def __init__(self, client=None):
11
+ """依存性解決クラスの初期化"""
12
+ super().__init__(client)
13
+ self.validation_results = {}
14
+
15
+ # 一般的なライブラリとそのインポート文のマッピング
16
+ self.common_libraries = {
17
+ # Python 標準ライブラリ
18
+ "os": "import os",
19
+ "sys": "import sys",
20
+ "re": "import re",
21
+ "math": "import math",
22
+ "random": "import random",
23
+ "datetime": "import datetime",
24
+ "time": "import time",
25
+ "json": "import json",
26
+ "csv": "import csv",
27
+ "collections": "import collections",
28
+ "functools": "import functools",
29
+ "itertools": "import itertools",
30
+ "pathlib": "import pathlib",
31
+ "typing": "import typing",
32
+ "logging": "import logging",
33
+ "argparse": "import argparse",
34
+ "subprocess": "import subprocess",
35
+ "threading": "import threading",
36
+ "multiprocessing": "import multiprocessing",
37
+ "asyncio": "import asyncio",
38
+ "io": "import io",
39
+
40
+ # データ処理と科学計算
41
+ "numpy": "import numpy as np",
42
+ "pandas": "import pandas as pd",
43
+ "scipy": "import scipy",
44
+ "matplotlib": "import matplotlib.pyplot as plt",
45
+ "seaborn": "import seaborn as sns",
46
+ "sklearn": "from sklearn import ...",
47
+ "tensorflow": "import tensorflow as tf",
48
+ "torch": "import torch",
49
+ "keras": "import keras",
50
+ "sympy": "import sympy",
51
+
52
+ # Web 関連
53
+ "requests": "import requests",
54
+ "flask": "from flask import Flask",
55
+ "django": "from django import ...",
56
+ "fastapi": "from fastapi import FastAPI",
57
+ "beautifulsoup4": "from bs4 import BeautifulSoup",
58
+ "sqlalchemy": "from sqlalchemy import ...",
59
+ "aiohttp": "import aiohttp",
60
+
61
+ # テスト関連
62
+ "pytest": "import pytest",
63
+ "unittest": "import unittest",
64
+ "mock": "from unittest import mock",
65
+
66
+ # ユーティリティ
67
+ "tqdm": "from tqdm import tqdm",
68
+ "pyyaml": "import yaml",
69
+ "pillow": "from PIL import Image",
70
+ "click": "import click",
71
+
72
+ # 特定の物理シミュレーション関連
73
+ "pymunk": "import pymunk",
74
+ "pybox2d": "import Box2D",
75
+ "mujoco": "import mujoco",
76
+ "pybullet": "import pybullet",
77
+ "opensim": "import opensim",
78
+ "vpython": "import vpython",
79
+ "fenics": "import fenics",
80
+ "dolfinx": "import dolfinx",
81
+ "pygmsh": "import pygmsh",
82
+ "meshio": "import meshio",
83
+ "pde": "import pde",
84
+ "mpmath": "import mpmath",
85
+ "dedalus": "import dedalus"
86
+ }
87
+
88
+ # Python のビルトイン関数・型リスト
89
+ self.python_builtins = set([
90
+ "abs", "all", "any", "ascii", "bin", "bool", "breakpoint", "bytearray", "bytes",
91
+ "callable", "chr", "classmethod", "compile", "complex", "delattr", "dict", "dir",
92
+ "divmod", "enumerate", "eval", "exec", "filter", "float", "format", "frozenset",
93
+ "getattr", "globals", "hasattr", "hash", "help", "hex", "id", "input", "int",
94
+ "isinstance", "issubclass", "iter", "len", "list", "locals", "map", "max", "memoryview",
95
+ "min", "next", "object", "oct", "open", "ord", "pow", "print", "property", "range",
96
+ "repr", "reversed", "round", "set", "setattr", "slice", "sorted", "staticmethod",
97
+ "str", "sum", "super", "tuple", "type", "vars", "zip", "__import__",
98
+
99
+ # 例外型
100
+ "BaseException", "Exception", "ArithmeticError", "BufferError", "LookupError",
101
+ "AssertionError", "AttributeError", "EOFError", "FloatingPointError", "GeneratorExit",
102
+ "ImportError", "ModuleNotFoundError", "IndexError", "KeyError", "KeyboardInterrupt",
103
+ "MemoryError", "NameError", "NotImplementedError", "OSError", "OverflowError",
104
+ "RecursionError", "ReferenceError", "RuntimeError", "StopIteration", "StopAsyncIteration",
105
+ "SyntaxError", "IndentationError", "TabError", "SystemError", "SystemExit", "TypeError",
106
+ "UnboundLocalError", "UnicodeError", "UnicodeEncodeError", "UnicodeDecodeError",
107
+ "UnicodeTranslateError", "ValueError", "ZeroDivisionError"
108
+ ])
109
+
110
+ async def validate(self, code, context=None):
111
+ """依存関係の解決を実行する(インターフェース実装)"""
112
+ analysis_result = self.analyze_dependencies(code)
113
+ self.validation_results = analysis_result
114
+
115
+ # 依存関係の解決が必要な場合
116
+ if analysis_result.get("missing_imports"):
117
+ resolved_code = self.resolve_dependencies(code)
118
+ return resolved_code
119
+
120
+ return code
121
+
122
+ def get_result_summary(self):
123
+ """検証結果の要約を返す(インターフェース実装)"""
124
+ if not self.validation_results:
125
+ return "依存関係の分析はまだ実行されていません。"
126
+
127
+ detected_imports = self.validation_results.get("detected_imports", [])
128
+ missing_imports = self.validation_results.get("missing_imports", [])
129
+ implicit_dependencies = self.validation_results.get("implicit_dependencies", [])
130
+
131
+ summary = f"検出されたインポート: {len(detected_imports)}個\n"
132
+ summary += f"暗黙的な依存関係: {len(implicit_dependencies)}個\n"
133
+
134
+ if missing_imports:
135
+ summary += "不足しているインポート:\n"
136
+ summary += "\n".join([f"- {imp}" for imp in missing_imports[:5]])
137
+ if len(missing_imports) > 5:
138
+ summary += f"\n...他{len(missing_imports) - 5}個"
139
+ else:
140
+ summary += "全ての依存関係が適切にインポートされています。"
141
+
142
+ return summary
143
+
144
+ def analyze_dependencies(self, code: str) -> Dict[str, Any]:
145
+ """コードの依存関係を解析する"""
146
+ result = {
147
+ "detected_imports": [],
148
+ "missing_imports": [],
149
+ "unused_imports": [],
150
+ "implicit_dependencies": [],
151
+ "import_statements": []
152
+ }
153
+
154
+ try:
155
+ # コードのASTを解析
156
+ tree = ast.parse(code)
157
+
158
+ # インポート文を抽出
159
+ for node in ast.walk(tree):
160
+ if isinstance(node, ast.Import):
161
+ for name in node.names:
162
+ result["detected_imports"].append(name.name)
163
+ result["import_statements"].append(f"import {name.name}")
164
+ elif isinstance(node, ast.ImportFrom):
165
+ module = node.module
166
+ for name in node.names:
167
+ result["detected_imports"].append(f"{module}.{name.name}")
168
+ result["import_statements"].append(f"from {module} import {name.name}")
169
+
170
+ # 暗黙的な依存関係を検出(コード内で参照されているが明示的にインポートされていない名前)
171
+ used_names = set()
172
+ for node in ast.walk(tree):
173
+ if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load):
174
+ used_names.add(node.id)
175
+
176
+ # ビルトイン以外のトップレベルの名前を抽出
177
+ top_level_names = set()
178
+ for node in ast.walk(tree):
179
+ if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store):
180
+ if not any(isinstance(parent, (ast.FunctionDef, ast.ClassDef)) for parent in ast.iter_child_nodes(tree)):
181
+ top_level_names.add(node.id)
182
+
183
+ # 使用されている名前から既知のライブラリを検出
184
+ for name in used_names:
185
+ # ビルトインでなく、トップレベルで定義されていない名前
186
+ if name not in self.python_builtins and name not in top_level_names:
187
+ # 名前がライブラリの一部として検出されるか確認
188
+ for lib_name, import_stmt in self.common_libraries.items():
189
+ if name == lib_name or name in ["np", "pd", "plt", "sns", "tf"] and self._is_common_alias(name, lib_name):
190
+ if lib_name not in result["detected_imports"] and f"{lib_name}." not in [d.split(".")[0] for d in result["detected_imports"]]:
191
+ result["implicit_dependencies"].append(name)
192
+ result["missing_imports"].append(self.common_libraries[lib_name])
193
+
194
+ # 特殊なライブラリ使用パターンの検出
195
+ self._detect_special_library_usage(code, result)
196
+
197
+ # 物理シミュレーション関連のパターン検出
198
+ self._detect_physics_simulation_libraries(code, result)
199
+
200
+ except SyntaxError as e:
201
+ print(f"[Error] Syntax error during dependency analysis: {str(e)}")
202
+ result["error"] = f"構文エラー: {str(e)}"
203
+ except Exception as e:
204
+ print(f"[Error] Exception during dependency analysis: {str(e)}")
205
+ result["error"] = f"解析エラー: {str(e)}"
206
+
207
+ return result
208
+
209
+ def _detect_special_library_usage(self, code: str, result: Dict[str, Any]) -> None:
210
+ """特殊なライブラリ使用パターンを検出する"""
211
+ # Numpyの数学関数パターン
212
+ numpy_math_patterns = [
213
+ r'\bsin\s*\(', r'\bcos\s*\(', r'\btan\s*\(', r'\bexp\s*\(', r'\blog\s*\(',
214
+ r'\bsqrt\s*\(', r'\barray\s*\(', r'\bmatrix\s*\(', r'\bzeros\s*\(', r'\bones\s*\(',
215
+ r'\blinspace\s*\(', r'\barange\s*\(', r'\breshape\s*\(', r'\btranspose\s*\('
216
+ ]
217
+
218
+ has_numpy_pattern = False
219
+ for pattern in numpy_math_patterns:
220
+ if re.search(pattern, code) and not re.search(r'math\.' + pattern[1:], code):
221
+ has_numpy_pattern = True
222
+ break
223
+
224
+ if has_numpy_pattern and "numpy" not in result["detected_imports"]:
225
+ result["implicit_dependencies"].append("numpy")
226
+ result["missing_imports"].append("import numpy as np")
227
+
228
+ # Pandas データフレーム操作パターン
229
+ pandas_patterns = [
230
+ r'\bDataFrame\s*\(', r'\.iloc\b', r'\.loc\b', r'\.groupby\s*\(', r'\.pivot\s*\(',
231
+ r'\.merge\s*\(', r'\.join\s*\(', r'\.concat\s*\(', r'\.read_csv\s*\(', r'\.to_csv\s*\('
232
+ ]
233
+
234
+ has_pandas_pattern = False
235
+ for pattern in pandas_patterns:
236
+ if re.search(pattern, code):
237
+ has_pandas_pattern = True
238
+ break
239
+
240
+ if has_pandas_pattern and "pandas" not in result["detected_imports"]:
241
+ result["implicit_dependencies"].append("pandas")
242
+ result["missing_imports"].append("import pandas as pd")
243
+
244
+ # Matplotlib プロット操作パターン
245
+ matplotlib_patterns = [
246
+ r'\bplot\s*\(', r'\bscatter\s*\(', r'\bfigure\s*\(', r'\bsubplot\s*\(', r'\bbar\s*\(',
247
+ r'\bhistogram\s*\(', r'\blegend\s*\(', r'\btitle\s*\(', r'\bxlabel\s*\(', r'\bylabel\s*\('
248
+ ]
249
+
250
+ has_matplotlib_pattern = False
251
+ for pattern in matplotlib_patterns:
252
+ if re.search(pattern, code) and not re.search(r'\bplt\.' + pattern[1:], code):
253
+ has_matplotlib_pattern = True
254
+ break
255
+
256
+ if has_matplotlib_pattern and "matplotlib" not in result["detected_imports"]:
257
+ result["implicit_dependencies"].append("matplotlib")
258
+ result["missing_imports"].append("import matplotlib.pyplot as plt")
259
+
260
+ def _detect_physics_simulation_libraries(self, code: str, result: Dict[str, Any]) -> None:
261
+ """物理シミュレーション関連のライブラリ使用を検出する"""
262
+ # 物理シミュレーション関連のパターン
263
+ physics_patterns = {
264
+ "pymunk": [r'\bBody\s*\(', r'\bSpace\s*\(', r'\bShape\s*\(', r'\bCircle\s*\(', r'\bSegment\s*\('],
265
+ "pybox2d": [r'\bWorld\s*\(', r'\bBox2D\b', r'\bBody\s*\(', r'\bJoint\s*\('],
266
+ "vpython": [r'\bsphere\s*\(', r'\bbox\s*\(', r'\bcylinder\s*\(', r'\bvector\s*\(', r'\brate\s*\('],
267
+ "mujoco": [r'\bmj_\w+\s*\(', r'\bmjData\b', r'\bmjModel\b'],
268
+ "pybullet": [r'\bpybullet\b', r'\bcreateCylinder\s*\(', r'\bcreateBox\s*\(', r'\bgetBasePositionAndOrientation\s*\('],
269
+ "fenics": [r'\bMesh\s*\(', r'\bFunctionSpace\s*\(', r'\bFunction\s*\(', r'\bsplit\s*\(', r'\bproject\s*\(']
270
+ }
271
+
272
+ for lib_name, patterns in physics_patterns.items():
273
+ for pattern in patterns:
274
+ if re.search(pattern, code):
275
+ if lib_name not in result["detected_imports"]:
276
+ result["implicit_dependencies"].append(lib_name)
277
+ result["missing_imports"].append(self.common_libraries.get(lib_name, f"import {lib_name}"))
278
+ break
279
+
280
+ def _is_common_alias(self, alias: str, library: str) -> bool:
281
+ """一般的なエイリアスかどうかをチェックする"""
282
+ common_aliases = {
283
+ "np": "numpy",
284
+ "pd": "pandas",
285
+ "plt": "matplotlib.pyplot",
286
+ "sns": "seaborn",
287
+ "tf": "tensorflow"
288
+ }
289
+ return common_aliases.get(alias) == library
290
+
291
+ def resolve_dependencies(self, code: str) -> str:
292
+ """コードの依存関係を解決し、必要なインポート文を追加する"""
293
+ # まず依存関係を分析
294
+ analysis = self.analyze_dependencies(code)
295
+
296
+ # エラーがある場合は元のコードを返す
297
+ if "error" in analysis:
298
+ print(f"[Warning] Dependency analysis error: {analysis['error']}")
299
+ return code
300
+
301
+ # インポート文の追加が必要ない場合
302
+ if not analysis["missing_imports"]:
303
+ return code
304
+
305
+ # コードの先頭にインポート文を追加
306
+ missing_imports = list(set(analysis["missing_imports"])) # 重複を削除
307
+
308
+ # 標準ライブラリ、サードパー���ィライブラリ、アプリケーションコードの順に並べる
309
+ std_libs = []
310
+ third_party_libs = []
311
+ app_libs = []
312
+
313
+ for imp in missing_imports:
314
+ if imp.startswith(("import os", "import sys", "import re", "import math", "import datetime", "import collections", "import functools", "import itertools", "import pathlib", "import typing", "import json", "import csv")):
315
+ std_libs.append(imp)
316
+ elif imp.startswith(("import numpy", "import pandas", "import matplotlib", "import scipy", "import sklearn", "import tensorflow", "import torch", "import keras", "import requests", "import flask", "import django", "import fastapi", "import bs4", "import sqlalchemy")):
317
+ third_party_libs.append(imp)
318
+ else:
319
+ app_libs.append(imp)
320
+
321
+ # 整理したインポート文を結合
322
+ imports_block = "\n".join(std_libs)
323
+ if std_libs and (third_party_libs or app_libs):
324
+ imports_block += "\n"
325
+ imports_block += "\n".join(third_party_libs)
326
+ if third_party_libs and app_libs:
327
+ imports_block += "\n"
328
+ imports_block += "\n".join(app_libs)
329
+
330
+ # 既存のインポート文を検出
331
+ existing_imports_match = re.search(r'^((?:(?:import|from)\s+.*?\n)+)', code, re.MULTILINE)
332
+
333
+ if existing_imports_match:
334
+ # 既存のインポート文の後に追加
335
+ existing_imports = existing_imports_match.group(1)
336
+ code = code.replace(existing_imports, existing_imports + "\n" + imports_block + "\n")
337
+ else:
338
+ # コードの先頭に追加
339
+ shebang_match = re.search(r'^(#!.*?\n)', code)
340
+ encoding_match = re.search(r'^(# -*coding[=:]\s*[-\w.]+.*?\n)', code)
341
+ docstring_match = re.search(r'^(""".*?"""\n)', code, re.DOTALL)
342
+
343
+ if shebang_match:
344
+ # シェバンの後に追加
345
+ shebang = shebang_match.group(1)
346
+ code = code.replace(shebang, shebang + "\n" + imports_block + "\n\n")
347
+ elif encoding_match:
348
+ # エンコーディング宣言の後に追加
349
+ encoding = encoding_match.group(1)
350
+ code = code.replace(encoding, encoding + "\n" + imports_block + "\n\n")
351
+ elif docstring_match:
352
+ # ドキュメントストリングの後に追加
353
+ docstring = docstring_match.group(1)
354
+ code = code.replace(docstring, docstring + "\n" + imports_block + "\n\n")
355
+ else:
356
+ # 単純に先頭に追加
357
+ code = imports_block + "\n\n" + code
358
+
359
+ return code