File size: 17,782 Bytes
00700bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
import re
import ast
from typing import Dict, Any, List, Optional, Set, Tuple

from validators.base import BaseValidator, Validator

class DependencyResolver(BaseValidator, Validator):
    """Detect library dependencies in Python source and add missing imports.

    The analysis is heuristic: used names are matched against
    ``common_libraries`` and a handful of regular-expression usage patterns.
    Nothing is executed or installed; the only output is an analysis dict and,
    optionally, the source text with import lines inserted.

    Line-based insertion relies on ``ast`` node ``end_lineno`` (Python 3.8+).
    """

    # Conventional short aliases and the module each one stands for.
    _COMMON_ALIASES = {
        "np": "numpy",
        "pd": "pandas",
        "plt": "matplotlib.pyplot",
        "sns": "seaborn",
        "tf": "tensorflow",
    }

    # Standard-library modules listed in ``common_libraries``; used to order
    # inserted imports (standard library first).
    _STDLIB_ROOTS = frozenset({
        "os", "sys", "re", "math", "random", "datetime", "time", "json", "csv",
        "collections", "functools", "itertools", "pathlib", "typing", "logging",
        "argparse", "subprocess", "threading", "multiprocessing", "asyncio", "io",
        "unittest",
    })

    # PEP 263 encoding declaration (only meaningful on line 1 or 2 of a file).
    _CODING_LINE = re.compile(r"^[ \t\f]*#.*?coding[:=][ \t]*[-_.a-zA-Z0-9]+")

    # Function names whose call suggests NumPy unless written as ``math.<fn>(``.
    _NUMPY_FUNCTIONS = (
        "sin", "cos", "tan", "exp", "log", "sqrt", "array", "matrix", "zeros",
        "ones", "linspace", "arange", "reshape", "transpose",
    )

    # DataFrame-style usage. ``.join(`` is deliberately absent: it matches
    # every ``str.join`` call and would flag pandas for almost any script.
    _PANDAS_PATTERNS = (
        r'\bDataFrame\s*\(', r'\.iloc\b', r'\.loc\b', r'\.groupby\s*\(',
        r'\.pivot\s*\(', r'\.merge\s*\(', r'\.concat\s*\(',
        r'\.read_csv\s*\(', r'\.to_csv\s*\(',
    )

    # Function names whose call suggests Matplotlib unless written ``plt.<fn>(``
    # (in which case the ``plt`` name itself is resolved by the name analysis).
    _MATPLOTLIB_FUNCTIONS = (
        "plot", "scatter", "figure", "subplot", "bar", "histogram", "legend",
        "title", "xlabel", "ylabel",
    )

    # Usage patterns for physics-simulation libraries, keyed by the
    # ``common_libraries`` entry they map to. ``split`` must be a bare call:
    # as a method it is almost always ``str.split``.
    _PHYSICS_PATTERNS = {
        "pymunk": (r'\bBody\s*\(', r'\bSpace\s*\(', r'\bShape\s*\(', r'\bCircle\s*\(', r'\bSegment\s*\('),
        "pybox2d": (r'\bWorld\s*\(', r'\bBox2D\b', r'\bBody\s*\(', r'\bJoint\s*\('),
        "vpython": (r'\bsphere\s*\(', r'\bbox\s*\(', r'\bcylinder\s*\(', r'\bvector\s*\(', r'\brate\s*\('),
        "mujoco": (r'\bmj_\w+\s*\(', r'\bmjData\b', r'\bmjModel\b'),
        "pybullet": (r'\bpybullet\b', r'\bcreateCylinder\s*\(', r'\bcreateBox\s*\(', r'\bgetBasePositionAndOrientation\s*\('),
        "fenics": (r'\bMesh\s*\(', r'\bFunctionSpace\s*\(', r'\bFunction\s*\(', r'(?<![\w.])split\s*\(', r'\bproject\s*\('),
    }

    def __init__(self, client=None):
        """Initialise the resolver.

        Args:
            client: Forwarded unchanged to the base-class initialiser.
        """
        super().__init__(client)
        # Result of the most recent ``validate`` call (empty until then).
        self.validation_results = {}

        # Library name -> import statement conventionally used for it.
        # Entries ending in "..." are placeholders, not valid Python; they are
        # never inserted into code (see ``_build_name_index``).
        self.common_libraries = {
            # Python standard library
            "os": "import os",
            "sys": "import sys",
            "re": "import re",
            "math": "import math",
            "random": "import random",
            "datetime": "import datetime",
            "time": "import time",
            "json": "import json",
            "csv": "import csv",
            "collections": "import collections",
            "functools": "import functools",
            "itertools": "import itertools",
            "pathlib": "import pathlib",
            "typing": "import typing",
            "logging": "import logging",
            "argparse": "import argparse",
            "subprocess": "import subprocess",
            "threading": "import threading",
            "multiprocessing": "import multiprocessing",
            "asyncio": "import asyncio",
            "io": "import io",

            # Data processing and scientific computing
            "numpy": "import numpy as np",
            "pandas": "import pandas as pd",
            "scipy": "import scipy",
            "matplotlib": "import matplotlib.pyplot as plt",
            "seaborn": "import seaborn as sns",
            "sklearn": "from sklearn import ...",
            "tensorflow": "import tensorflow as tf",
            "torch": "import torch",
            "keras": "import keras",
            "sympy": "import sympy",

            # Web
            "requests": "import requests",
            "flask": "from flask import Flask",
            "django": "from django import ...",
            "fastapi": "from fastapi import FastAPI",
            "beautifulsoup4": "from bs4 import BeautifulSoup",
            "sqlalchemy": "from sqlalchemy import ...",
            "aiohttp": "import aiohttp",

            # Testing
            "pytest": "import pytest",
            "unittest": "import unittest",
            "mock": "from unittest import mock",

            # Utilities
            "tqdm": "from tqdm import tqdm",
            "pyyaml": "import yaml",
            "pillow": "from PIL import Image",
            "click": "import click",

            # Physics simulation
            "pymunk": "import pymunk",
            "pybox2d": "import Box2D",
            "mujoco": "import mujoco",
            "pybullet": "import pybullet",
            "opensim": "import opensim",
            "vpython": "import vpython",
            "fenics": "import fenics",
            "dolfinx": "import dolfinx",
            "pygmsh": "import pygmsh",
            "meshio": "import meshio",
            "pde": "import pde",
            "mpmath": "import mpmath",
            "dedalus": "import dedalus"
        }

        # Built-in functions, types and exception classes that never need an import.
        self.python_builtins = {
            "abs", "all", "any", "ascii", "bin", "bool", "breakpoint", "bytearray", "bytes",
            "callable", "chr", "classmethod", "compile", "complex", "delattr", "dict", "dir",
            "divmod", "enumerate", "eval", "exec", "filter", "float", "format", "frozenset",
            "getattr", "globals", "hasattr", "hash", "help", "hex", "id", "input", "int",
            "isinstance", "issubclass", "iter", "len", "list", "locals", "map", "max", "memoryview",
            "min", "next", "object", "oct", "open", "ord", "pow", "print", "property", "range",
            "repr", "reversed", "round", "set", "setattr", "slice", "sorted", "staticmethod",
            "str", "sum", "super", "tuple", "type", "vars", "zip", "__import__",

            # Exception types
            "BaseException", "Exception", "ArithmeticError", "BufferError", "LookupError",
            "AssertionError", "AttributeError", "EOFError", "FloatingPointError", "GeneratorExit",
            "ImportError", "ModuleNotFoundError", "IndexError", "KeyError", "KeyboardInterrupt",
            "MemoryError", "NameError", "NotImplementedError", "OSError", "OverflowError",
            "RecursionError", "ReferenceError", "RuntimeError", "StopIteration", "StopAsyncIteration",
            "SyntaxError", "IndentationError", "TabError", "SystemError", "SystemExit", "TypeError",
            "UnboundLocalError", "UnicodeError", "UnicodeEncodeError", "UnicodeDecodeError",
            "UnicodeTranslateError", "ValueError", "ZeroDivisionError"
        }

    async def validate(self, code, context=None):
        """Analyse ``code`` and return it with missing imports added (interface implementation).

        The analysis is stored in ``self.validation_results`` for
        ``get_result_summary``.

        Args:
            code: Python source text.
            context: Not used by this validator.

        Returns:
            The source with imports inserted, or ``code`` unchanged when
            nothing is missing or the analysis failed.
        """
        analysis_result = self.analyze_dependencies(code)
        self.validation_results = analysis_result

        if analysis_result.get("missing_imports"):
            return self.resolve_dependencies(code)

        return code

    def get_result_summary(self):
        """Return a human-readable summary of the last analysis (interface implementation).

        Lists at most five missing imports. When the last analysis failed, the
        stored error message is reported instead of claiming that every
        dependency is imported.
        """
        if not self.validation_results:
            return "依存関係の分析はまだ実行されていません。"

        detected_imports = self.validation_results.get("detected_imports", [])
        missing_imports = self.validation_results.get("missing_imports", [])
        implicit_dependencies = self.validation_results.get("implicit_dependencies", [])

        summary = f"検出されたインポート: {len(detected_imports)}個\n"
        summary += f"暗黙的な依存関係: {len(implicit_dependencies)}個\n"

        # A failed analysis has empty lists; do not present that as "all good".
        error = self.validation_results.get("error")
        if error:
            return summary + str(error)

        if missing_imports:
            summary += "不足しているインポート:\n"
            summary += "\n".join([f"- {imp}" for imp in missing_imports[:5]])
            if len(missing_imports) > 5:
                summary += f"\n...他{len(missing_imports) - 5}個"
        else:
            summary += "全ての依存関係が適切にインポートされています。"

        return summary

    def analyze_dependencies(self, code: str) -> Dict[str, Any]:
        """Analyse the imports and library usage of ``code``.

        Args:
            code: Python source text.

        Returns:
            A dict with these keys:

            * ``detected_imports``: imported names (``module`` or ``module.name``).
            * ``import_statements``: one normalised statement per imported
              name (``as`` aliases are not reproduced).
            * ``implicit_dependencies``: names/libraries used without an import.
            * ``missing_imports``: import statements to add (no duplicates).
            * ``unused_imports``: always empty; never populated here.
            * ``error``: present only when parsing or analysis failed.
        """
        result: Dict[str, Any] = {
            "detected_imports": [],
            "missing_imports": [],
            "unused_imports": [],
            "implicit_dependencies": [],
            "import_statements": []
        }

        try:
            tree = ast.parse(code)
            self._collect_imports(tree, result)
            self._detect_unbound_library_names(tree, result)
            self._detect_special_library_usage(code, result)
            self._detect_physics_simulation_libraries(code, result)
        except SyntaxError as e:
            print(f"[Error] Syntax error during dependency analysis: {str(e)}")
            result["error"] = f"構文エラー: {str(e)}"
        except Exception as e:
            # Boundary handler: analysis must never crash the validation run.
            print(f"[Error] Exception during dependency analysis: {str(e)}")
            result["error"] = f"解析エラー: {str(e)}"

        return result

    @staticmethod
    def _collect_imports(tree: ast.AST, result: Dict[str, Any]) -> None:
        """Record every import statement found anywhere in ``tree``."""
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    result["detected_imports"].append(alias.name)
                    result["import_statements"].append(f"import {alias.name}")
            elif isinstance(node, ast.ImportFrom):
                # Relative imports have ``module=None`` and/or ``level > 0``.
                module = "." * node.level + (node.module or "")
                separator = "." if node.module else ""
                for alias in node.names:
                    result["detected_imports"].append(f"{module}{separator}{alias.name}")
                    result["import_statements"].append(f"from {module} import {alias.name}")

    @staticmethod
    def _collect_bound_names(tree: ast.AST) -> Set[str]:
        """Return every name that ``tree`` binds in any scope.

        Covers assignment-like targets, function/class names, parameters,
        import bindings and ``except ... as`` names. Scopes are not tracked:
        a name bound anywhere counts as defined, which errs on the side of
        not adding imports.
        """
        bound: Set[str] = set()
        for node in ast.walk(tree):
            if isinstance(node, ast.Name):
                if isinstance(node.ctx, ast.Store):
                    bound.add(node.id)
            elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                bound.add(node.name)
            elif isinstance(node, ast.arg):
                bound.add(node.arg)
            elif isinstance(node, ast.alias):
                # ``import a.b`` binds ``a``; ``... as c`` binds ``c``.
                bound.add(node.asname or node.name.split(".")[0])
            elif isinstance(node, ast.ExceptHandler) and node.name:
                bound.add(node.name)
        return bound

    @staticmethod
    def _names_bound_by(statement: str) -> Set[str]:
        """Return the names an import statement binds (empty if it does not parse)."""
        try:
            tree = ast.parse(statement)
        except (SyntaxError, ValueError):
            return set()
        return {
            node.asname or node.name.split(".")[0]
            for node in ast.walk(tree)
            if isinstance(node, ast.alias)
        }

    @staticmethod
    def _module_root(statement: str) -> Optional[str]:
        """Return the top-level module named by an import statement, if any."""
        match = re.match(r"\s*(?:from|import)\s+([A-Za-z_]\w*)", statement)
        return match.group(1) if match else None

    @staticmethod
    def _imported_roots(result: Dict[str, Any]) -> Set[str]:
        """Return the top-level packages of all imports detected so far."""
        return {name.split(".")[0] for name in result["detected_imports"]}

    @staticmethod
    def _record_missing(result: Dict[str, Any], dependency: str, statement: str) -> None:
        """Add a dependency and its import statement to ``result`` without duplicates."""
        if dependency not in result["implicit_dependencies"]:
            result["implicit_dependencies"].append(dependency)
        if statement not in result["missing_imports"]:
            result["missing_imports"].append(statement)

    def _build_name_index(self) -> Dict[str, str]:
        """Map each resolvable name to an import statement that binds it.

        A name bound by a table statement (``np``, ``plt``, ``tqdm``, ...)
        maps to that statement. A library key that is itself the module root
        (``numpy``, ``flask``, ``sklearn``, ...) maps to a plain
        ``import <name>``, because the table statement would bind a different
        name or is only a placeholder.
        """
        index: Dict[str, str] = {}
        for statement in self.common_libraries.values():
            for bound in self._names_bound_by(statement):
                index.setdefault(bound, statement)
        for lib_name, statement in self.common_libraries.items():
            if self._module_root(statement) == lib_name:
                index.setdefault(lib_name, f"import {lib_name}")
        return index

    def _detect_unbound_library_names(self, tree: ast.AST, result: Dict[str, Any]) -> None:
        """Flag names that are read but never bound and match a known library."""
        used_names = {
            node.id
            for node in ast.walk(tree)
            if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load)
        }
        bound_names = self._collect_bound_names(tree)
        name_index = self._build_name_index()

        # Sorted so that the result order is stable between runs.
        for name in sorted(used_names - bound_names):
            if name in self.python_builtins:
                continue
            statement = name_index.get(name)
            if statement is not None:
                self._record_missing(result, name, statement)

    @staticmethod
    def _has_unqualified_call(code: str, function_names: Tuple[str, ...], qualifier: str) -> bool:
        """Return True if some function is called but never as ``<qualifier><fn>(``.

        Args:
            code: Source text to search.
            function_names: Plain identifiers to look for.
            qualifier: Regex for the module prefix that excuses the call.
        """
        for function_name in function_names:
            call = rf"{function_name}\s*\("
            if re.search(r"\b" + call, code) and not re.search(qualifier + call, code):
                return True
        return False

    def _detect_special_library_usage(self, code: str, result: Dict[str, Any]) -> None:
        """Detect NumPy, pandas and Matplotlib usage from textual patterns.

        The patterns are matched on raw text, so they also fire inside
        comments and string literals. A library is only reported when none of
        its modules is imported yet.
        """
        imported = self._imported_roots(result)

        if "numpy" not in imported and self._has_unqualified_call(
            code, self._NUMPY_FUNCTIONS, r"math\."
        ):
            self._record_missing(result, "numpy", "import numpy as np")

        if "pandas" not in imported and any(
            re.search(pattern, code) for pattern in self._PANDAS_PATTERNS
        ):
            self._record_missing(result, "pandas", "import pandas as pd")

        if "matplotlib" not in imported and self._has_unqualified_call(
            code, self._MATPLOTLIB_FUNCTIONS, r"\bplt\."
        ):
            self._record_missing(result, "matplotlib", "import matplotlib.pyplot as plt")

    def _detect_physics_simulation_libraries(self, code: str, result: Dict[str, Any]) -> None:
        """Detect physics-simulation libraries from textual usage patterns."""
        imported = self._imported_roots(result)

        for lib_name, patterns in self._PHYSICS_PATTERNS.items():
            statement = self.common_libraries.get(lib_name, f"import {lib_name}")
            # The importable module can differ from the library key
            # (e.g. "pybox2d" is imported as ``Box2D``).
            module = self._module_root(statement) or lib_name
            if lib_name in imported or module in imported:
                continue
            if any(re.search(pattern, code) for pattern in patterns):
                self._record_missing(result, lib_name, statement)

    def _is_common_alias(self, alias: str, library: str) -> bool:
        """Return True if ``alias`` is the conventional short name for ``library``.

        ``library`` may be the aliased module itself (``matplotlib.pyplot``)
        or its top-level package (``matplotlib``).
        """
        target = self._COMMON_ALIASES.get(alias)
        if target is None:
            return False
        return library in (target, target.split(".")[0])

    def _order_imports(self, statements: List[str]) -> List[str]:
        """Deduplicate and order statements: stdlib, known third-party, others.

        Statements that do not parse as an import are dropped so that invalid
        text is never written into the user's code.
        """
        known_roots = {self._module_root(stmt) for stmt in self.common_libraries.values()}
        known_roots.discard(None)

        std_libs: List[str] = []
        third_party_libs: List[str] = []
        app_libs: List[str] = []

        for statement in sorted(set(statements)):
            if not self._names_bound_by(statement):
                continue
            root = self._module_root(statement)
            if root in self._STDLIB_ROOTS:
                std_libs.append(statement)
            elif root in known_roots:
                third_party_libs.append(statement)
            else:
                app_libs.append(statement)

        return std_libs + third_party_libs + app_libs

    def _leading_header_lines(self, code: str) -> int:
        """Return how many initial lines hold a shebang / encoding declaration."""
        lines = code.split("\n", 2)
        keep = 0
        if lines[0].startswith("#!") or self._CODING_LINE.match(lines[0]):
            keep = 1
        if len(lines) > 1 and self._CODING_LINE.match(lines[1]):
            first = lines[0].strip()
            # An encoding declaration on line 2 counts only after a
            # blank/comment first line.
            if not first or first.startswith("#"):
                keep = 2
        return keep

    @staticmethod
    def _offset_after_lines(code: str, line_count: int) -> int:
        """Return the character offset just past the first ``line_count`` lines."""
        offset = 0
        for _ in range(line_count):
            newline = code.find("\n", offset)
            if newline == -1:
                return len(code)
            offset = newline + 1
        return offset

    def _insert_imports(self, code: str, import_lines: List[str]) -> str:
        """Insert ``import_lines`` at the appropriate place in ``code``.

        Placement, in order of preference: after the leading run of top-level
        imports, after the module docstring, after a shebang/encoding header,
        or at the very top. Positions come from the AST, so multi-line
        imports and docstrings are never split.
        """
        block = "\n".join(import_lines) + "\n"
        body = ast.parse(code).body

        has_docstring = (
            bool(body)
            and isinstance(body[0], ast.Expr)
            and isinstance(body[0].value, ast.Constant)
            and isinstance(body[0].value.value, str)
        )
        first_import = 1 if has_docstring else 0
        index = first_import
        while index < len(body) and isinstance(body[index], (ast.Import, ast.ImportFrom)):
            index += 1

        if index > first_import:
            # Blank line, then the new imports, directly below existing ones.
            keep_lines = body[index - 1].end_lineno
            addition = "\n" + block
        elif has_docstring:
            keep_lines = body[0].end_lineno
            addition = "\n" + block + "\n"
        else:
            keep_lines = self._leading_header_lines(code)
            addition = ("\n" + block + "\n") if keep_lines else (block + "\n")

        offset = self._offset_after_lines(code, keep_lines)
        head, tail = code[:offset], code[offset:]
        if head and not head.endswith("\n"):
            head += "\n"
        return head + addition + tail

    def resolve_dependencies(self, code: str) -> str:
        """Return ``code`` with the import statements it is missing added.

        Args:
            code: Python source text.

        Returns:
            The updated source, or ``code`` unchanged when the analysis
            failed or no import needs to be added.
        """
        analysis = self.analyze_dependencies(code)

        # Unparseable code cannot be edited safely; hand it back untouched.
        if "error" in analysis:
            print(f"[Warning] Dependency analysis error: {analysis['error']}")
            return code

        if not analysis["missing_imports"]:
            return code

        import_lines = self._order_imports(analysis["missing_imports"])
        if not import_lines:
            return code

        return self._insert_imports(code, import_lines)