|
|
import ast |
|
|
import os |
|
|
import logging |
|
|
import tempfile |
|
|
from typing import List, Dict, Any, Tuple, Optional |
|
|
from clang import cindex |
|
|
import javalang |
|
|
import javalang.tree as T |
|
|
import esprima |
|
|
from bs4 import BeautifulSoup |
|
|
import tree_sitter_rust as ts_rust |
|
|
from tree_sitter import Language, Parser |
|
|
import re |
|
|
from .utils.path_utils import generate_entity_aliases |
|
|
|
|
|
|
|
|
|
|
|
LOGGER_NAME = "AST_ENTITY_EXTRACTOR" |
|
|
logger = logging.getLogger(LOGGER_NAME) |
|
|
|
|
|
|
|
|
class BaseASTEntityExtractor:
    """Abstract interface implemented by every language-specific extractor.

    Concrete subclasses (HTML, JavaScript, Java, C, C++) must provide both
    methods below; this base class only defines the contract.
    """

    def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
        """Parse *code* and return the entities it declares and references.

        Args:
            code: The source code to analyze, as a single string.
            file_path: Optional on-disk path of the source (gives parsers
                better context, e.g. for resolving C includes).

        Returns:
            A ``(declared_entities, called_entities)`` pair, where the first
            element is a list of entity dicts and the second a list of
            referenced names.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError

    def reset(self) -> None:
        """Clear any accumulated state so this instance can be reused.

        Subclasses override this to empty their internal buffers between
        successive ``extract_entities`` calls.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError
|
|
|
|
|
class HTMLEntityExtractor(BaseASTEntityExtractor):
    """
    Hybrid HTML AST-based entity extractor.

    Responsibilities:
        • Parse HTML into a tree
        • Extract declared DOM entities (ids, names, classes)
        • Extract JavaScript calls from inline event handlers
        • Extract JS entities from <script> tags
        • Integrate cleanly with the hybrid AST graph linker
    """

    # Inline event-handler attributes all start with "on" (onclick, onload, ...).
    EVENT_ATTR_PREFIX = "on"

    def __init__(self):
        # JS parsing for inline handlers and <script> bodies is delegated
        # to the JavaScript extractor.
        self.js_extractor = JavaScriptEntityExtractor()
        self.reset()

    def reset(self):
        """Clear accumulated results so the instance can be reused."""
        self.declared_entities: List[Dict[str, str]] = []
        self.called_entities: List[str] = []

    def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, str]], List[str]]:
        """Main entry point: parse HTML and extract entities.

        Returns ([], []) on any parse failure instead of raising, matching
        the error behavior of the other extractors in this module.
        """
        self.reset()
        try:
            soup = BeautifulSoup(code, "html.parser")
        except Exception as e:
            # Fix: report through the module logger (not print), consistent
            # with JavaEntityExtractor / JavaScriptEntityExtractor.
            logger.error(f"[HTMLEntityExtractor] Parsing error: {e}")
            return [], []

        # Walk every tag once: record DOM declarations and scan inline
        # event-handler attributes for JS calls.
        for tag in soup.find_all(True):
            self._handle_tag_declaration(tag)
            self._handle_event_attributes(tag)

        # <script> tags are handled separately (src references vs inline JS).
        for script in soup.find_all("script"):
            self._handle_script(script)

        self.declared_entities = self._deduplicate_dicts(self.declared_entities)
        self.called_entities = self._deduplicate_list(self.called_entities)

        return self.declared_entities, self.called_entities

    def _handle_tag_declaration(self, tag):
        """Extract declared DOM elements (id, name, class)."""
        if tag.has_attr("id"):
            self.declared_entities.append({"name": tag["id"], "type": "element"})

        if tag.has_attr("name"):
            self.declared_entities.append({"name": tag["name"], "type": "element"})

        if tag.has_attr("class"):
            # BeautifulSoup normally yields the class attribute as a list,
            # but a plain string is tolerated as well.
            classes = tag["class"]
            if isinstance(classes, list):
                for c in classes:
                    self.declared_entities.append({"name": c, "type": "class"})
            elif isinstance(classes, str):
                self.declared_entities.append({"name": classes, "type": "class"})

    def _handle_event_attributes(self, tag):
        """Extract JS calls from inline event attributes (onclick, ...)."""
        if not self.js_extractor:
            return
        for attr, value in tag.attrs.items():
            if attr.lower().startswith(self.EVENT_ATTR_PREFIX) and isinstance(value, str):
                try:
                    # Only the called entities matter for handler snippets;
                    # declarations inside them are discarded.
                    _, called = self.js_extractor.extract_entities(value)
                    self.called_entities.extend(called)
                except Exception as e:
                    logger.error(f"[HTMLEntityExtractor] JS parse error in {attr}: {e}")

    def _handle_script(self, script):
        """Extract JS entities from <script> blocks or src attributes."""
        if script.has_attr("src"):
            # External script: record the src reference and stop — there is
            # no inline body to parse.
            src = script["src"]
            self.called_entities.append(src)
            return

        if not self.js_extractor:
            return

        # NOTE: .string is None when the tag has multiple children (e.g. a
        # comment plus code); such scripts are skipped — TODO confirm whether
        # get_text() should be used instead.
        js_code = (script.string or "").strip()
        if js_code:
            try:
                declared, called = self.js_extractor.extract_entities(js_code)
                self.declared_entities.extend(declared)
                self.called_entities.extend(called)
            except Exception as e:
                logger.error(f"[HTMLEntityExtractor] JS parse error in <script>: {e}")

    @staticmethod
    def _deduplicate_dicts(dicts: List[Dict]) -> List[Dict]:
        """Remove duplicate dicts, preserving first-seen order."""
        seen = set()
        result = []
        for d in dicts:
            # Sorted item tuples give a hashable, order-insensitive key.
            key = tuple(sorted(d.items()))
            if key not in seen:
                seen.add(key)
                result.append(d)
        return result

    @staticmethod
    def _deduplicate_list(items: List[str]) -> List[str]:
        """Remove duplicate strings, preserving first-seen order."""
        seen = set()
        result = []
        for i in items:
            if i not in seen:
                seen.add(i)
                result.append(i)
        return result
|
|
|
|
|
|
|
|
class JavaEntityExtractor(BaseASTEntityExtractor):
    """
    Extract declared and called entities from Java code using javalang.
    Produces the same (declared_entities, called_entities) structure as other extractors.
    """

    def __init__(self):
        self.reset()

    def reset(self) -> None:
        # Accumulators for one extract_entities() run.
        self.declared_entities: List[Dict[str, Any]] = []
        self.called_entities: List[str] = []
        # Package of the current compilation unit, prefixed onto type names.
        self.current_package: Optional[str] = None
        # Enclosing type names; joined with "::" to build qualified names.
        self.scope_stack: List[str] = []
        # Spring endpoint records collected from mapping annotations.
        self.api_endpoints: List[Dict[str, Any]] = []
        # Base path from a class-level @RequestMapping, combined with method paths.
        self.current_class_base_path: Optional[str] = None

    def _qualified(self, name: str) -> str:
        # Prefix name with the current scope stack, "::"-separated.
        if not name:
            return ""
        scope = "::".join(self.scope_stack)
        return f"{scope}::{name}" if scope else name

    def _walk_type(self, t):
        """Return string representation of a type node."""
        if not t:
            return "unknown"
        if isinstance(t, str):
            return t
        if hasattr(t, "name"):
            name = t.name
            # Append generic type arguments, e.g. List<String>.
            if getattr(t, "arguments", None):
                args = [self._walk_type(a.type) for a in t.arguments if hasattr(a, "type")]
                name += "<" + ", ".join(args) + ">"
            return name
        return "unknown"

    def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
        """Parse Java source and return (declared_entities, called_entities).

        Returns ([], []) on any parse error instead of raising.
        """
        self.reset()

        try:
            tree = javalang.parse.parse(code)
        except javalang.parser.JavaSyntaxError as e:
            logger.error(f"Syntax error in Java code: {e}")
            return [], []
        except Exception as e:
            logger.error(f"Error parsing Java code: {e}", exc_info=True)
            return [], []

        if tree.package:
            self.current_package = tree.package.name

        # Imports are recorded as called (referenced) entities.
        for imp in tree.imports:
            self.called_entities.append(imp.path)

        # Visit every top-level type declaration (class/interface/enum).
        for type_decl in tree.types:
            self._visit_type(type_decl)

        # Deduplicate declarations by (name, type, dtype), preserving order.
        seen_decl = set()
        unique_declared = []
        for e in self.declared_entities:
            key = (e.get("name"), e.get("type"), e.get("dtype"))
            if key not in seen_decl:
                unique_declared.append(e)
                seen_decl.add(key)

        # dict.fromkeys deduplicates while keeping first-seen order.
        unique_called = list(dict.fromkeys(self.called_entities))
        return unique_declared, unique_called

    def _visit_type(self, node):
        # Dispatch on the concrete javalang type-declaration node.
        if isinstance(node, javalang.tree.ClassDeclaration):
            self._visit_class(node)
        elif isinstance(node, javalang.tree.InterfaceDeclaration):
            self._visit_interface(node)
        elif isinstance(node, javalang.tree.EnumDeclaration):
            self._visit_enum(node)

    def _visit_class(self, node):
        # Build the package-qualified class name.
        full_name = node.name
        if self.current_package:
            full_name = f"{self.current_package}.{node.name}"
        qualified = self._qualified(full_name)

        self.declared_entities.append({"name": qualified, "type": "class"})

        # Class-level Spring annotations: @RequestMapping sets a base path
        # for the endpoint methods inside; restored after the class body.
        old_base_path = self.current_class_base_path
        if node.annotations:
            for annotation in node.annotations:
                if annotation.name in {'RestController', 'Controller'}:
                    # Marker only; no path information to extract.
                    pass
                elif annotation.name == 'RequestMapping':
                    self.current_class_base_path = self._extract_path_from_annotation(annotation)

        # Superclass and implemented interfaces count as referenced types.
        if node.extends:
            self.called_entities.append(self._walk_type(node.extends))
        for impl in node.implements or []:
            self.called_entities.append(self._walk_type(impl))

        self.scope_stack.append(full_name)
        for member in node.body:
            self._visit_member(member)
        self.scope_stack.pop()

        self.current_class_base_path = old_base_path

    def _visit_interface(self, node):
        full_name = node.name
        if self.current_package:
            full_name = f"{self.current_package}.{node.name}"
        qualified = self._qualified(full_name)
        self.declared_entities.append({"name": qualified, "type": "interface"})

        # Interfaces extend (not implement) other interfaces.
        for impl in node.extends or []:
            self.called_entities.append(self._walk_type(impl))

        self.scope_stack.append(full_name)
        for member in node.body:
            self._visit_member(member)
        self.scope_stack.pop()

    def _visit_enum(self, node):
        # Enums are recorded as a declaration only; the body is not walked.
        full_name = node.name
        if self.current_package:
            full_name = f"{self.current_package}.{node.name}"
        qualified = self._qualified(full_name)
        self.declared_entities.append({"name": qualified, "type": "enum"})

    def _visit_member(self, node):
        # Handle each kind of class/interface member.
        if isinstance(node, T.MethodDeclaration):
            method_name = self._qualified(node.name)

            # Methods carrying Spring mapping annotations are recorded as
            # API endpoints instead of plain methods.
            api_info = self._extract_api_endpoint_from_annotations(node)
            if api_info:
                self.declared_entities.append({
                    "name": method_name,
                    "type": "api_endpoint",
                    "endpoint": api_info.get("endpoint"),
                    "methods": api_info.get("methods")
                })
                self.api_endpoints.append({**api_info, "function": method_name})
            else:
                self.declared_entities.append({"name": method_name, "type": "method"})

            # Parameters become variables scoped under the method name.
            for param in node.parameters:
                ptype = self._walk_type(param.type)
                pname = f"{method_name}.{param.name}"
                self.declared_entities.append({
                    "name": pname,
                    "type": "variable",
                    "dtype": ptype
                })

            # Abstract/interface methods have no body.
            if node.body:
                self._find_calls(node.body)

        elif isinstance(node, T.ConstructorDeclaration):
            ctor_name = self._qualified(node.name)
            self.declared_entities.append({"name": ctor_name, "type": "constructor"})
            for param in node.parameters:
                ptype = self._walk_type(param.type)
                pname = f"{ctor_name}.{param.name}"
                self.declared_entities.append({
                    "name": pname,
                    "type": "variable",
                    "dtype": ptype
                })
            if node.body:
                self._find_calls(node.body)

        elif isinstance(node, T.FieldDeclaration):
            # One FieldDeclaration can declare several variables (int a, b;).
            dtype = self._walk_type(node.type)
            for decl in node.declarators:
                var_name = self._qualified(decl.name)
                self.declared_entities.append({
                    "name": var_name,
                    "type": "variable",
                    "dtype": dtype
                })

        elif isinstance(node, (T.ClassDeclaration, T.InterfaceDeclaration)):
            # Nested types are visited recursively.
            self._visit_type(node)

    def _extract_api_endpoint_from_annotations(self, method) -> Optional[Dict[str, Any]]:
        """
        Extract API endpoint information from Spring Boot method annotations.
        Handles: @GetMapping, @PostMapping, @RequestMapping, etc.
        """
        if not method.annotations:
            return None

        for annotation in method.annotations:
            annotation_name = annotation.name

            if annotation_name in {'GetMapping', 'PostMapping', 'PutMapping', 'PatchMapping', 'DeleteMapping'}:
                # Shorthand annotations encode the HTTP method in their name.
                http_method = annotation_name.replace('Mapping', '').upper()
                path = self._extract_path_from_annotation(annotation)

                if path:
                    full_path = self._combine_paths(self.current_class_base_path, path)
                    return {
                        "endpoint": full_path,
                        "methods": [http_method],
                        "type": "api_endpoint_definition"
                    }

            elif annotation_name == 'RequestMapping':
                # Generic mapping: path and HTTP methods are separate elements.
                path = self._extract_path_from_annotation(annotation)
                methods = self._extract_methods_from_annotation(annotation)

                if path:
                    full_path = self._combine_paths(self.current_class_base_path, path)
                    return {
                        "endpoint": full_path,
                        # Spring defaults to all methods; GET is assumed here.
                        "methods": methods if methods else ['GET'],
                        "type": "api_endpoint_definition"
                    }

        return None

    def _extract_path_from_annotation(self, annotation) -> Optional[str]:
        """Extract path/value from Spring annotation."""
        if not annotation.element:
            return None

        # @GetMapping("/path") — single positional literal.
        if isinstance(annotation.element, T.Literal):
            return annotation.element.value.strip('"')

        # @RequestMapping(value = "/path", ...) — list of name/value pairs.
        if isinstance(annotation.element, list):
            for elem in annotation.element:
                if isinstance(elem, T.ElementValuePair):
                    if elem.name in {'value', 'path'}:
                        if isinstance(elem.value, T.Literal):
                            return elem.value.value.strip('"')
                        elif isinstance(elem.value, T.ElementArrayValue):
                            # value = {"/a", "/b"} — only the first path is used.
                            if elem.value.values:
                                first_val = elem.value.values[0]
                                if isinstance(first_val, T.Literal):
                                    return first_val.value.strip('"')

        return None

    def _extract_methods_from_annotation(self, annotation) -> List[str]:
        """Extract HTTP methods from @RequestMapping annotation."""
        methods = []

        if isinstance(annotation.element, list):
            for elem in annotation.element:
                if isinstance(elem, T.ElementValuePair):
                    if elem.name == 'method':
                        # method = RequestMethod.GET — a member reference.
                        if hasattr(elem.value, 'member'):
                            methods.append(elem.value.member)
                        elif isinstance(elem.value, T.ElementArrayValue):
                            # method = {RequestMethod.GET, RequestMethod.POST}
                            for val in elem.value.values:
                                if hasattr(val, 'member'):
                                    methods.append(val.member)

        return methods

    def _combine_paths(self, base_path: Optional[str], path: str) -> str:
        """Combine base path from class annotation with method path."""
        if not base_path:
            return path

        # Normalize so exactly one slash joins the two segments.
        base = base_path.rstrip('/')
        path = path.lstrip('/')

        return f"{base}/{path}" if path else base

    def _find_calls(self, statements):
        """Recursively find method and constructor calls inside Java AST nodes."""

        def _recurse(node):
            if isinstance(node, T.MethodInvocation):
                # Qualified calls are recorded as "qualifier.member".
                if node.qualifier:
                    self.called_entities.append(f"{node.qualifier}.{node.member}")
                else:
                    self.called_entities.append(node.member)
            elif isinstance(node, T.ClassCreator):
                # "new Foo(...)" references type Foo.
                self.called_entities.append(self._walk_type(node.type))

            # Generic descent: walk every child attribute that holds AST nodes.
            if hasattr(node, '__dict__'):
                for attr, val in vars(node).items():
                    if isinstance(val, list):
                        for child in val:
                            if isinstance(child, T.Node):
                                _recurse(child)
                    elif isinstance(val, T.Node):
                        _recurse(val)

        if not statements:
            return

        # Accept either a statement list (method body) or a single node.
        if isinstance(statements, list):
            for stmt in statements:
                _recurse(stmt)
        else:
            _recurse(statements)
|
|
|
|
|
|
|
|
class JavaScriptEntityExtractor(BaseASTEntityExtractor): |
|
|
""" |
|
|
Extract declared and called entities from JavaScript code using esprima. |
|
|
Handles ES6+ syntax including classes, arrow functions, imports/exports. |
|
|
Also detects API endpoint calls (fetch, axios, etc.). |
|
|
""" |
|
|
|
|
|
|
|
|
HTTP_METHODS = {'get', 'post', 'put', 'patch', 'delete', 'head', 'options'} |
|
|
|
|
|
|
|
|
API_PATTERNS = { |
|
|
'fetch', |
|
|
'axios', |
|
|
'$http', |
|
|
'request', |
|
|
'superagent', |
|
|
} |
|
|
|
|
|
def __init__(self): |
|
|
self.reset() |
|
|
|
|
|
def reset(self) -> None: |
|
|
self.declared_entities: List[Dict[str, Any]] = [] |
|
|
self.called_entities: List[str] = [] |
|
|
self.scope_stack: List[str] = [] |
|
|
self.api_calls: List[Dict[str, Any]] = [] |
|
|
|
|
|
def _qualified(self, name: str) -> str: |
|
|
"""Return fully qualified name using current scope stack.""" |
|
|
if not name: |
|
|
return "" |
|
|
scope = ".".join(self.scope_stack) |
|
|
return f"{scope}.{name}" if scope else name |
|
|
|
|
|
def _get_function_name(self, node) -> Optional[str]: |
|
|
"""Extract function name from various function node types.""" |
|
|
if hasattr(node, 'id') and node.id: |
|
|
return node.id.name |
|
|
return None |
|
|
|
|
|
def _walk_node(self, node): |
|
|
"""Recursively walk the AST and extract entities.""" |
|
|
if not node or not hasattr(node, 'type'): |
|
|
return |
|
|
|
|
|
node_type = node.type |
|
|
|
|
|
|
|
|
if node_type == 'FunctionDeclaration': |
|
|
func_name = self._get_function_name(node) |
|
|
if func_name: |
|
|
qualified = self._qualified(func_name) |
|
|
self.declared_entities.append({"name": qualified, "type": "function"}) |
|
|
|
|
|
|
|
|
if hasattr(node, 'params'): |
|
|
for param in node.params: |
|
|
param_name = self._extract_pattern_name(param) |
|
|
if param_name: |
|
|
self.declared_entities.append({ |
|
|
"name": f"{qualified}.{param_name}", |
|
|
"type": "variable", |
|
|
"dtype": "unknown" |
|
|
}) |
|
|
|
|
|
self.scope_stack.append(func_name) |
|
|
if hasattr(node, 'body'): |
|
|
self._walk_node(node.body) |
|
|
self.scope_stack.pop() |
|
|
|
|
|
|
|
|
elif node_type == 'ArrowFunctionExpression': |
|
|
|
|
|
if hasattr(node, 'params'): |
|
|
for param in node.params: |
|
|
param_name = self._extract_pattern_name(param) |
|
|
|
|
|
if hasattr(node, 'body'): |
|
|
self._walk_node(node.body) |
|
|
|
|
|
|
|
|
elif node_type == 'FunctionExpression': |
|
|
func_name = self._get_function_name(node) |
|
|
if func_name: |
|
|
qualified = self._qualified(func_name) |
|
|
self.declared_entities.append({"name": qualified, "type": "function"}) |
|
|
self.scope_stack.append(func_name) |
|
|
|
|
|
if hasattr(node, 'params'): |
|
|
for param in node.params: |
|
|
param_name = self._extract_pattern_name(param) |
|
|
if param_name and func_name: |
|
|
self.declared_entities.append({ |
|
|
"name": f"{self._qualified(func_name)}.{param_name}", |
|
|
"type": "variable", |
|
|
"dtype": "unknown" |
|
|
}) |
|
|
|
|
|
if hasattr(node, 'body'): |
|
|
self._walk_node(node.body) |
|
|
|
|
|
if func_name: |
|
|
self.scope_stack.pop() |
|
|
|
|
|
|
|
|
elif node_type == 'ClassDeclaration': |
|
|
class_name = node.id.name if hasattr(node, 'id') and node.id else None |
|
|
if class_name: |
|
|
qualified = self._qualified(class_name) |
|
|
self.declared_entities.append({"name": qualified, "type": "class"}) |
|
|
|
|
|
|
|
|
if hasattr(node, 'superClass') and node.superClass: |
|
|
if hasattr(node.superClass, 'name'): |
|
|
self.called_entities.append(node.superClass.name) |
|
|
|
|
|
self.scope_stack.append(class_name) |
|
|
if hasattr(node, 'body') and hasattr(node.body, 'body'): |
|
|
for method in node.body.body: |
|
|
self._walk_node(method) |
|
|
self.scope_stack.pop() |
|
|
|
|
|
|
|
|
elif node_type == 'MethodDefinition': |
|
|
method_name = node.key.name if hasattr(node, 'key') and hasattr(node.key, 'name') else None |
|
|
if method_name: |
|
|
qualified = self._qualified(method_name) |
|
|
self.declared_entities.append({"name": qualified, "type": "method"}) |
|
|
|
|
|
if hasattr(node, 'value') and hasattr(node.value, 'params'): |
|
|
for param in node.value.params: |
|
|
param_name = self._extract_pattern_name(param) |
|
|
if param_name: |
|
|
self.declared_entities.append({ |
|
|
"name": f"{qualified}.{param_name}", |
|
|
"type": "variable", |
|
|
"dtype": "unknown" |
|
|
}) |
|
|
|
|
|
if hasattr(node, 'value'): |
|
|
self._walk_node(node.value) |
|
|
|
|
|
|
|
|
elif node_type == 'VariableDeclaration': |
|
|
if hasattr(node, 'declarations'): |
|
|
for decl in node.declarations: |
|
|
self._walk_node(decl) |
|
|
|
|
|
|
|
|
elif node_type == 'VariableDeclarator': |
|
|
var_name = self._extract_pattern_name(node.id) if hasattr(node, 'id') else None |
|
|
if var_name: |
|
|
qualified = self._qualified(var_name) |
|
|
|
|
|
|
|
|
if hasattr(node, 'init') and node.init: |
|
|
if node.init.type in ('FunctionExpression', 'ArrowFunctionExpression'): |
|
|
self.declared_entities.append({"name": qualified, "type": "function"}) |
|
|
self.scope_stack.append(var_name) |
|
|
self._walk_node(node.init) |
|
|
self.scope_stack.pop() |
|
|
else: |
|
|
self.declared_entities.append({ |
|
|
"name": qualified, |
|
|
"type": "variable", |
|
|
"dtype": "unknown" |
|
|
}) |
|
|
self._walk_node(node.init) |
|
|
else: |
|
|
self.declared_entities.append({ |
|
|
"name": qualified, |
|
|
"type": "variable", |
|
|
"dtype": "unknown" |
|
|
}) |
|
|
|
|
|
|
|
|
elif node_type == 'CallExpression': |
|
|
callee_name = self._extract_callee_name(node.callee) if hasattr(node, 'callee') else None |
|
|
if callee_name: |
|
|
self.called_entities.append(callee_name) |
|
|
|
|
|
|
|
|
self._detect_api_call(node, callee_name) |
|
|
|
|
|
|
|
|
if hasattr(node, 'arguments'): |
|
|
for arg in node.arguments: |
|
|
self._walk_node(arg) |
|
|
|
|
|
|
|
|
elif node_type == 'MemberExpression': |
|
|
|
|
|
if hasattr(node, 'object'): |
|
|
self._walk_node(node.object) |
|
|
if hasattr(node, 'property'): |
|
|
self._walk_node(node.property) |
|
|
|
|
|
|
|
|
elif node_type == 'ImportDeclaration': |
|
|
if hasattr(node, 'source') and hasattr(node.source, 'value'): |
|
|
self.called_entities.append(node.source.value) |
|
|
|
|
|
elif node_type == 'ExportNamedDeclaration': |
|
|
if hasattr(node, 'declaration'): |
|
|
self._walk_node(node.declaration) |
|
|
|
|
|
elif node_type == 'ExportDefaultDeclaration': |
|
|
if hasattr(node, 'declaration'): |
|
|
self._walk_node(node.declaration) |
|
|
|
|
|
|
|
|
else: |
|
|
if hasattr(node, '__dict__'): |
|
|
for attr, val in vars(node).items(): |
|
|
if isinstance(val, list): |
|
|
for item in val: |
|
|
if hasattr(item, 'type'): |
|
|
self._walk_node(item) |
|
|
elif hasattr(val, 'type'): |
|
|
self._walk_node(val) |
|
|
|
|
|
def _extract_pattern_name(self, pattern) -> Optional[str]: |
|
|
"""Extract name from various pattern types (Identifier, ObjectPattern, etc.).""" |
|
|
if not pattern: |
|
|
return None |
|
|
if hasattr(pattern, 'type'): |
|
|
if pattern.type == 'Identifier': |
|
|
return pattern.name if hasattr(pattern, 'name') else None |
|
|
elif pattern.type == 'RestElement': |
|
|
return self._extract_pattern_name(pattern.argument) if hasattr(pattern, 'argument') else None |
|
|
return None |
|
|
|
|
|
def _extract_callee_name(self, callee) -> Optional[str]: |
|
|
"""Extract the name of the function being called.""" |
|
|
if not callee: |
|
|
return None |
|
|
|
|
|
if hasattr(callee, 'type'): |
|
|
if callee.type == 'Identifier': |
|
|
return callee.name if hasattr(callee, 'name') else None |
|
|
elif callee.type == 'MemberExpression': |
|
|
obj = self._extract_callee_name(callee.object) if hasattr(callee, 'object') else "" |
|
|
prop = callee.property.name if hasattr(callee, 'property') and hasattr(callee.property, 'name') else "" |
|
|
if obj and prop: |
|
|
return f"{obj}.{prop}" |
|
|
return prop or obj |
|
|
return None |
|
|
|
|
|
def _detect_api_call(self, call_node, callee_name: str): |
|
|
""" |
|
|
Detect API endpoint calls in JavaScript code. |
|
|
Handles patterns like: |
|
|
- fetch('/api/users') |
|
|
- axios.get('/api/users') |
|
|
- axios.post('/api/users', data) |
|
|
- request.get('/api/users') |
|
|
""" |
|
|
if not callee_name or not hasattr(call_node, 'arguments'): |
|
|
return |
|
|
|
|
|
|
|
|
parts = callee_name.split('.') |
|
|
base = parts[0] |
|
|
method = parts[-1].lower() if len(parts) > 1 else None |
|
|
|
|
|
|
|
|
is_api_call = False |
|
|
http_method = 'unknown' |
|
|
|
|
|
|
|
|
if base == 'fetch': |
|
|
is_api_call = True |
|
|
http_method = 'GET' |
|
|
|
|
|
|
|
|
elif base in self.API_PATTERNS and method in self.HTTP_METHODS: |
|
|
is_api_call = True |
|
|
http_method = method.upper() |
|
|
|
|
|
|
|
|
elif base in self.API_PATTERNS and method is None: |
|
|
is_api_call = True |
|
|
http_method = 'GET' |
|
|
|
|
|
if not is_api_call: |
|
|
return |
|
|
|
|
|
|
|
|
if call_node.arguments: |
|
|
first_arg = call_node.arguments[0] |
|
|
endpoint = self._extract_string_literal(first_arg) |
|
|
|
|
|
if endpoint: |
|
|
|
|
|
self.called_entities.append(f"API:{http_method}:{endpoint}") |
|
|
|
|
|
|
|
|
self.api_calls.append({ |
|
|
"endpoint": endpoint, |
|
|
"method": http_method, |
|
|
"type": "api_call" |
|
|
}) |
|
|
|
|
|
def _extract_string_literal(self, node) -> Optional[str]: |
|
|
"""Extract string value from a Literal/TemplateLiteral node.""" |
|
|
if not node or not hasattr(node, 'type'): |
|
|
return None |
|
|
|
|
|
if node.type == 'Literal' and isinstance(node.value, str): |
|
|
return node.value |
|
|
elif node.type == 'TemplateLiteral': |
|
|
|
|
|
|
|
|
if hasattr(node, 'quasis'): |
|
|
parts = [] |
|
|
for i, quasi in enumerate(node.quasis): |
|
|
if hasattr(quasi, 'value') and hasattr(quasi.value, 'raw'): |
|
|
parts.append(quasi.value.raw) |
|
|
if i < len(node.quasis) - 1: |
|
|
parts.append('{param}') |
|
|
return ''.join(parts) |
|
|
|
|
|
return None |
|
|
|
|
|
def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]: |
|
|
self.reset() |
|
|
|
|
|
try: |
|
|
tree = esprima.parseScript(code, {'tolerant': True, 'loc': False}) |
|
|
except Exception as e: |
|
|
|
|
|
try: |
|
|
tree = esprima.parseModule(code, {'tolerant': True, 'loc': False}) |
|
|
except Exception as e2: |
|
|
logger.error(f"Failed to parse JavaScript code: {e2}") |
|
|
return [], [] |
|
|
|
|
|
if hasattr(tree, 'body'): |
|
|
for node in tree.body: |
|
|
self._walk_node(node) |
|
|
|
|
|
|
|
|
seen_decl = set() |
|
|
unique_declared = [] |
|
|
for e in self.declared_entities: |
|
|
key = (e.get("name"), e.get("type"), e.get("dtype")) |
|
|
if key not in seen_decl: |
|
|
unique_declared.append(e) |
|
|
seen_decl.add(key) |
|
|
|
|
|
unique_called = list(dict.fromkeys(self.called_entities)) |
|
|
return unique_declared, unique_called |
|
|
|
|
|
|
|
|
class CEntityExtractor(BaseASTEntityExtractor):
    """
    Extract declared and called entities from C code using clang.cindex (libclang),
    with filtering to ignore system headers.
    """

    def __init__(self):
        # One libclang index is reused across extract_entities() calls.
        self.index = cindex.Index.create()

    def reset(self) -> None:
        """No persistent state to reset, but method provided for interface consistency."""
        pass

    def _walk_cursor(self, cursor, declared, called, source_file):
        """Recursively walk a clang Cursor, restricted to the main file."""
        for c in cursor.get_children():
            # #include directives are recorded as referenced entities even
            # though they originate outside the main file.
            if c.kind == cindex.CursorKind.INCLUSION_DIRECTIVE:
                included_file = c.displayname
                if included_file:
                    called.append(included_file)
                continue

            loc = c.location
            if not loc.file or not source_file:
                continue

            # Skip anything pulled in from other files (system headers etc.).
            if os.path.abspath(loc.file.name) != os.path.abspath(source_file):
                continue

            if c.kind.is_declaration():
                if c.kind in (cindex.CursorKind.FUNCTION_DECL, cindex.CursorKind.FUNCTION_TEMPLATE):
                    name = c.spelling or c.displayname
                    declared.append({"name": name, "type": "function"})
                    # Parameters become variables scoped under the function.
                    for p in c.get_arguments():
                        declared.append({
                            "name": f"{name}.{p.spelling}",
                            "type": "variable",
                            "dtype": p.type.spelling
                        })
                elif c.kind == cindex.CursorKind.VAR_DECL:
                    declared.append({
                        "name": c.spelling,
                        "type": "variable",
                        "dtype": c.type.spelling
                    })

                    # A variable of a user-defined type references that type.
                    if c.type.spelling:
                        type_name = c.type.spelling.strip()
                        # Strip qualifiers/decorations to get the bare type name.
                        type_name = type_name.replace('const', '').replace('&', '').replace('*', '').replace('struct', '').strip()
                        # Primitive types are not interesting as references.
                        if type_name and type_name not in {
                            'int', 'float', 'double', 'char', 'bool', 'void',
                            'long', 'short', 'unsigned', 'signed', 'size_t',
                        }:
                            called.append(type_name)
                elif c.kind == cindex.CursorKind.STRUCT_DECL:
                    declared.append({"name": c.spelling or c.displayname, "type": "struct"})
                elif c.kind == cindex.CursorKind.TYPEDEF_DECL:
                    declared.append({"name": c.spelling, "type": "typedef"})

            # Record function calls; prefer the referenced declaration's
            # spelling over the call expression's display name.
            if c.kind == cindex.CursorKind.CALL_EXPR:
                callee = None
                for child in c.get_children():
                    if child.kind in (cindex.CursorKind.DECL_REF_EXPR, cindex.CursorKind.MEMBER_REF_EXPR):
                        callee = child.spelling
                        break
                if callee:
                    called.append(callee)
                else:
                    called.append(c.displayname or c.spelling)

            self._walk_cursor(c, declared, called, source_file)

    def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
        """Parse C source and return (declared_entities, called_entities).

        Raises:
            RuntimeError: If libclang fails to parse the translation unit.
        """
        declared, called = [], []

        # libclang needs a real file; use the given path when it exists,
        # otherwise write the code to a temporary .c file.
        tf_name = None
        temp_file = False

        if file_path and os.path.exists(file_path):
            tf_name = file_path
            temp_file = False
        else:
            with tempfile.NamedTemporaryFile(suffix=".c", mode="w+", delete=False) as tf:
                tf_name = tf.name
                tf.write(code)
                tf.flush()
            temp_file = True

        # Fix: the temp file was leaked when index.parse() raised, because
        # cleanup only ran after a successful walk. try/finally guarantees
        # removal on every path.
        try:
            # Let the parser resolve local includes next to the source file.
            include_dir = os.path.dirname(tf_name) if tf_name else None
            args = ['-std=c11']
            if include_dir:
                args.append(f'-I{include_dir}')

            try:
                tu = self.index.parse(
                    tf_name,
                    args=args,
                    # Detailed preprocessing record exposes INCLUSION_DIRECTIVE cursors.
                    options=cindex.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD
                )
            except Exception as e:
                raise RuntimeError(f"libclang failed to parse translation unit: {e}")

            self._walk_cursor(tu.cursor, declared, called, tf_name)

            # Deduplicate declarations by (name, type, dtype), preserving order.
            seen_decl = set()
            unique_declared = []
            for e in declared:
                key = (e.get("name"), e.get("type"), e.get("dtype", None))
                if key not in seen_decl:
                    unique_declared.append(e)
                    seen_decl.add(key)

            unique_called = list(dict.fromkeys(called))

            return unique_declared, unique_called
        finally:
            if temp_file:
                try:
                    os.unlink(tf_name)
                except OSError:
                    # Best-effort cleanup; a stale temp file is not fatal.
                    pass
|
|
|
|
|
|
|
|
class CppEntityExtractor(BaseASTEntityExtractor): |
|
|
""" |
|
|
Extract declared and called entities from C++ code using clang.cindex (libclang), |
|
|
including classes, namespaces, and methods. |
|
|
""" |
|
|
|
|
|
    def __init__(self):
        # One libclang index is created up front and reused for every parse.
        self.index = cindex.Index.create()
        self.reset()
|
|
|
|
|
    def reset(self) -> None:
        """Clear accumulated entities and scope state for a fresh extraction run."""
        # Entities declared in the translation unit.
        self.declared_entities = []
        # Names referenced (calls, base classes, includes).
        self.called_entities = []
        # Enclosing namespace/class names; joined with "::" by _qualified().
        self.scope_stack = []
|
|
|
|
|
    def _qualified(self, name: str) -> str:
        """Return fully qualified name using current scope stack."""
        if not name:
            return ""
        if not self.scope_stack:
            # Top-level name: no namespace/class prefix needed.
            return name
        # C++-style qualification, e.g. ns::Class::name.
        return "::".join(self.scope_stack + [name])
|
|
|
|
|
def _walk_cursor(self, cursor, source_file: str): |
|
|
for c in cursor.get_children(): |
|
|
|
|
|
|
|
|
if c.kind == cindex.CursorKind.INCLUSION_DIRECTIVE: |
|
|
|
|
|
included_file = c.displayname |
|
|
if included_file: |
|
|
self.called_entities.append(included_file) |
|
|
continue |
|
|
|
|
|
kind = c.kind |
|
|
|
|
|
|
|
|
if kind == cindex.CursorKind.NAMESPACE: |
|
|
if c.spelling: |
|
|
self.scope_stack.append(c.spelling) |
|
|
self._walk_cursor(c, source_file) |
|
|
if c.spelling: |
|
|
self.scope_stack.pop() |
|
|
continue |
|
|
|
|
|
|
|
|
loc = c.location |
|
|
|
|
|
if loc.file and os.path.abspath(loc.file.name) != os.path.abspath(source_file): |
|
|
continue |
|
|
|
|
|
|
|
|
if kind in (cindex.CursorKind.CLASS_DECL, cindex.CursorKind.STRUCT_DECL): |
|
|
|
|
|
if c.spelling: |
|
|
|
|
|
is_def = c.is_definition() if hasattr(c, 'is_definition') else True |
|
|
if is_def: |
|
|
full_name = self._qualified(c.spelling) |
|
|
self.declared_entities.append({"name": full_name, "type": "class"}) |
|
|
|
|
|
|
|
|
for base in c.get_children(): |
|
|
if base.kind == cindex.CursorKind.CXX_BASE_SPECIFIER: |
|
|
if base.spelling: |
|
|
self.called_entities.append(base.spelling) |
|
|
|
|
|
self.scope_stack.append(c.spelling) |
|
|
self._walk_cursor(c, source_file) |
|
|
self.scope_stack.pop() |
|
|
continue |
|
|
|
|
|
|
|
|
if kind in (cindex.CursorKind.CXX_METHOD, cindex.CursorKind.CONSTRUCTOR, cindex.CursorKind.DESTRUCTOR): |
|
|
if c.spelling: |
|
|
full_name = self._qualified(c.spelling) |
|
|
self.declared_entities.append({"name": full_name, "type": "method"}) |
|
|
|
|
|
for p in c.get_arguments(): |
|
|
if p.spelling: |
|
|
self.declared_entities.append({ |
|
|
"name": f"{full_name}.{p.spelling}", |
|
|
"type": "variable", |
|
|
"dtype": p.type.spelling |
|
|
}) |
|
|
|
|
|
self._walk_cursor(c, source_file) |
|
|
continue |
|
|
|
|
|
|
|
|
if kind == cindex.CursorKind.FUNCTION_DECL: |
|
|
if c.spelling: |
|
|
full_name = self._qualified(c.spelling) |
|
|
self.declared_entities.append({"name": full_name, "type": "function"}) |
|
|
for p in c.get_arguments(): |
|
|
if p.spelling: |
|
|
self.declared_entities.append({ |
|
|
"name": f"{full_name}.{p.spelling}", |
|
|
"type": "variable", |
|
|
"dtype": p.type.spelling |
|
|
}) |
|
|
self._walk_cursor(c, source_file) |
|
|
continue |
|
|
|
|
|
|
|
|
if kind == cindex.CursorKind.VAR_DECL: |
|
|
full_name = self._qualified(c.spelling) |
|
|
self.declared_entities.append({ |
|
|
"name": full_name, |
|
|
"type": "variable", |
|
|
"dtype": c.type.spelling |
|
|
}) |
|
|
|
|
|
|
|
|
|
|
|
type_ref_found = False |
|
|
for child in c.get_children(): |
|
|
if child.kind == cindex.CursorKind.TYPE_REF: |
|
|
|
|
|
|
|
|
if child.spelling: |
|
|
type_name = child.spelling.replace('class ', '').replace('struct ', '').strip() |
|
|
if type_name: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.called_entities.append(type_name) |
|
|
type_ref_found = True |
|
|
break |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not type_ref_found and c.type.spelling: |
|
|
|
|
|
type_name = c.type.spelling.strip() |
|
|
|
|
|
type_name = type_name.replace('const', '').replace('&', '').replace('*', '').strip() |
|
|
if type_name and not type_name in ['int', 'float', 'double', 'char', 'bool', 'void', 'long', 'short', 'unsigned', 'signed']: |
|
|
|
|
|
|
|
|
|
|
|
self.called_entities.append(type_name) |
|
|
|
|
|
|
|
|
if kind == cindex.CursorKind.CALL_EXPR: |
|
|
callee = None |
|
|
for child in c.get_children(): |
|
|
if child.kind in (cindex.CursorKind.DECL_REF_EXPR, cindex.CursorKind.MEMBER_REF_EXPR): |
|
|
callee = child.spelling |
|
|
break |
|
|
if callee: |
|
|
self.called_entities.append(callee) |
|
|
else: |
|
|
self.called_entities.append(c.displayname or c.spelling) |
|
|
|
|
|
|
|
|
self._walk_cursor(c, source_file) |
|
|
|
|
|
def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]: |
|
|
self.reset() |
|
|
|
|
|
|
|
|
|
|
|
tf_name = None |
|
|
temp_file = False |
|
|
|
|
|
if file_path and os.path.exists(file_path): |
|
|
tf_name = file_path |
|
|
temp_file = False |
|
|
else: |
|
|
with tempfile.NamedTemporaryFile(suffix=".cpp", mode="w+", delete=False) as tf: |
|
|
tf_name = tf.name |
|
|
tf.write(code) |
|
|
tf.flush() |
|
|
temp_file = True |
|
|
|
|
|
|
|
|
include_dir = os.path.dirname(tf_name) if tf_name else None |
|
|
args = ['-std=c++17', '-xc++'] |
|
|
if include_dir: |
|
|
args.append(f'-I{include_dir}') |
|
|
|
|
|
try: |
|
|
tu = self.index.parse( |
|
|
tf_name, |
|
|
args=args, |
|
|
options=cindex.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD |
|
|
) |
|
|
except Exception as e: |
|
|
raise RuntimeError(f"libclang failed to parse C++ translation unit: {e}") |
|
|
|
|
|
self._walk_cursor(tu.cursor, tf_name) |
|
|
|
|
|
|
|
|
seen_decl = set() |
|
|
unique_declared = [] |
|
|
for e in self.declared_entities: |
|
|
key = (e.get("name"), e.get("type"), e.get("dtype", None)) |
|
|
if key not in seen_decl: |
|
|
unique_declared.append(e) |
|
|
seen_decl.add(key) |
|
|
|
|
|
unique_called = list(dict.fromkeys(self.called_entities)) |
|
|
|
|
|
|
|
|
if temp_file: |
|
|
try: |
|
|
os.unlink(tf_name) |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
return unique_declared, unique_called |
|
|
|
|
|
|
|
|
class RustEntityExtractor(BaseASTEntityExtractor):
    """
    Extract declared and called entities from Rust code using tree-sitter.

    Handles structs, enums, traits, functions, methods, modules, type
    aliases, constants, statics and let-bindings. Also detects API endpoint
    definitions declared through routing attributes such as Actix-web /
    Rocket style ``#[get("/path")]`` and ``#[route("/path", method = "...")]``.
    """

    # HTTP-method attribute macros used by common Rust web frameworks.
    ROUTE_MACROS = {
        'get', 'post', 'put', 'patch', 'delete', 'head', 'options',
        'Get', 'Post', 'Put', 'Patch', 'Delete', 'Head', 'Options',
    }

    # Route-builder call patterns (Actix-web style).
    ROUTE_PATTERNS = {
        'route',
        'web::get', 'web::post', 'web::put', 'web::delete',
    }

    def __init__(self):
        self.parser = Parser()
        self.parser.language = Language(ts_rust.language())
        self.reset()

    def reset(self) -> None:
        """Clear per-run buffers so the extractor instance can be reused."""
        self.declared_entities = []
        self.called_entities = []
        self.scope_stack = []
        self.api_endpoints: List[Dict[str, Any]] = []

    def _qualified(self, name: str) -> str:
        """Return fully qualified name using the current scope stack."""
        if not name:
            return ""
        if not self.scope_stack:
            return name
        return "::".join(self.scope_stack + [name])

    def _get_node_text(self, node, code_bytes: bytes) -> str:
        """Return the UTF-8 source text covered by *node*."""
        return code_bytes[node.start_byte:node.end_byte].decode('utf8')

    def _extract_api_endpoint_from_attributes(self, node, code_bytes: bytes) -> Optional[Dict[str, Any]]:
        """
        Extract API endpoint information from a function's routing attributes.

        Handles patterns like:
          - #[get("/users")]                   (Actix-web, Rocket)
          - #[post("/users")]
          - #[route("/users", method="GET")]   (generic route)

        Note: in the tree-sitter Rust grammar, attributes appear as PREVIOUS
        SIBLINGS of the function_item node, not as its children.
        """
        parent = node.parent
        if not parent:
            return None

        # Locate this function among its siblings.
        node_index = None
        for i, child in enumerate(parent.children):
            if child == node:
                node_index = i
                break

        if node_index is None:
            return None

        # Scan backwards over the attributes/comments immediately preceding
        # the function; stop at the first unrelated sibling.
        for i in range(node_index - 1, -1, -1):
            sibling = parent.children[i]

            if sibling.type not in ['attribute_item', 'line_comment', 'block_comment']:
                break

            if sibling.type == 'attribute_item':
                attr_text = self._get_node_text(sibling, code_bytes)

                # #[get("/path")] style: the macro name doubles as HTTP method.
                method_pattern = r'#\[(get|post|put|patch|delete|head|options)\s*\(\s*"([^"]+)"(?:\s*,.*?)?\s*\)\]'
                match = re.search(method_pattern, attr_text, re.IGNORECASE)
                if match:
                    http_method = match.group(1).upper()
                    endpoint_path = match.group(2)
                    return {
                        "endpoint": endpoint_path,
                        "methods": [http_method],
                        "type": "api_endpoint_definition"
                    }

                # #[route("/path", method = "GET")] style; method defaults to GET.
                route_pattern = r'#\[route\s*\(\s*"([^"]+)"(?:.*?method\s*=\s*"([^"]+)")?\s*\)\]'
                match = re.search(route_pattern, attr_text, re.IGNORECASE)
                if match:
                    endpoint_path = match.group(1)
                    http_method = match.group(2).upper() if match.group(2) else "GET"
                    return {
                        "endpoint": endpoint_path,
                        "methods": [http_method],
                        "type": "api_endpoint_definition"
                    }

        return None

    def _walk_tree(self, node, code_bytes: bytes):
        """Recursively walk the tree-sitter AST collecting declarations/calls."""
        node_type = node.type

        if node_type == 'mod_item':
            # mod foo { ... } -- declares a module and opens a new scope.
            name_node = node.child_by_field_name('name')
            if name_node:
                mod_name = self._get_node_text(name_node, code_bytes)
                qualified = self._qualified(mod_name)
                self.declared_entities.append({"name": qualified, "type": "module"})

                self.scope_stack.append(mod_name)
                body = node.child_by_field_name('body')
                if body:
                    for child in body.children:
                        self._walk_tree(child, code_bytes)
                self.scope_stack.pop()
            return

        elif node_type == 'struct_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                struct_name = self._get_node_text(name_node, code_bytes)
                qualified = self._qualified(struct_name)
                self.declared_entities.append({"name": qualified, "type": "struct"})

                # Generic parameters may reference other types.
                type_params = node.child_by_field_name('type_parameters')
                if type_params:
                    self._walk_tree(type_params, code_bytes)

                self.scope_stack.append(struct_name)

                # Record each field as "<Struct>.<field>" with its type.
                body = node.child_by_field_name('body')
                if body:
                    for child in body.children:
                        if child.type == 'field_declaration':
                            field_name_node = child.child_by_field_name('name')
                            field_type_node = child.child_by_field_name('type')
                            if field_name_node:
                                field_name = self._get_node_text(field_name_node, code_bytes)
                                field_type = self._get_node_text(field_type_node, code_bytes) if field_type_node else "unknown"
                                self.declared_entities.append({
                                    "name": f"{qualified}.{field_name}",
                                    "type": "field",
                                    "dtype": field_type
                                })
                self.scope_stack.pop()
            return

        elif node_type == 'enum_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                enum_name = self._get_node_text(name_node, code_bytes)
                qualified = self._qualified(enum_name)
                self.declared_entities.append({"name": qualified, "type": "enum"})

                self.scope_stack.append(enum_name)
                body = node.child_by_field_name('body')
                if body:
                    for child in body.children:
                        if child.type == 'enum_variant':
                            variant_name_node = child.child_by_field_name('name')
                            if variant_name_node:
                                variant_name = self._get_node_text(variant_name_node, code_bytes)
                                self.declared_entities.append({
                                    "name": f"{qualified}::{variant_name}",
                                    "type": "enum_variant"
                                })
                self.scope_stack.pop()
            return

        elif node_type == 'trait_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                trait_name = self._get_node_text(name_node, code_bytes)
                qualified = self._qualified(trait_name)
                self.declared_entities.append({"name": qualified, "type": "trait"})

                self.scope_stack.append(trait_name)
                body = node.child_by_field_name('body')
                if body:
                    for child in body.children:
                        self._walk_tree(child, code_bytes)
                self.scope_stack.pop()
            return

        elif node_type == 'impl_item':
            # impl Type { ... } or impl Trait for Type { ... }
            type_node = node.child_by_field_name('type')
            trait_node = node.child_by_field_name('trait')

            impl_name = None
            if type_node:
                impl_name = self._get_node_text(type_node, code_bytes)

            if trait_node:
                # Implementing a trait counts as using it.
                trait_name = self._get_node_text(trait_node, code_bytes)
                self.called_entities.append(trait_name)

            if impl_name:
                self.scope_stack.append(impl_name)

            body = node.child_by_field_name('body')
            if body:
                for child in body.children:
                    self._walk_tree(child, code_bytes)

            if impl_name:
                self.scope_stack.pop()
            return

        elif node_type == 'function_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                func_name = self._get_node_text(name_node, code_bytes)
                qualified = self._qualified(func_name)

                # Routing attributes (#[get("/...")]) mark API endpoints.
                api_info = self._extract_api_endpoint_from_attributes(node, code_bytes)

                if api_info:
                    self.declared_entities.append({
                        "name": qualified,
                        "type": "api_endpoint",
                        "endpoint": api_info.get("endpoint"),
                        "methods": api_info.get("methods")
                    })
                    self.api_endpoints.append({**api_info, "function": qualified})
                else:
                    # A fn inside an impl/trait/mod scope is a method.
                    entity_type = "method" if len(self.scope_stack) > 0 else "function"
                    self.declared_entities.append({"name": qualified, "type": entity_type})

                # Parameters; self receivers are skipped.
                params = node.child_by_field_name('parameters')
                if params:
                    for child in params.children:
                        if child.type == 'parameter':
                            pattern = child.child_by_field_name('pattern')
                            type_node = child.child_by_field_name('type')
                            if pattern:
                                param_name = self._get_node_text(pattern, code_bytes)
                                param_type = self._get_node_text(type_node, code_bytes) if type_node else "unknown"
                                if param_name not in ['self', '&self', '&mut self', 'mut self']:
                                    self.declared_entities.append({
                                        "name": f"{qualified}.{param_name}",
                                        "type": "variable",
                                        "dtype": param_type
                                    })

                body = node.child_by_field_name('body')
                if body:
                    self._walk_tree(body, code_bytes)
            return

        elif node_type == 'type_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                type_name = self._get_node_text(name_node, code_bytes)
                qualified = self._qualified(type_name)
                self.declared_entities.append({"name": qualified, "type": "type_alias"})
            return

        elif node_type == 'const_item':
            name_node = node.child_by_field_name('name')
            type_node = node.child_by_field_name('type')
            if name_node:
                const_name = self._get_node_text(name_node, code_bytes)
                const_type = self._get_node_text(type_node, code_bytes) if type_node else "unknown"
                qualified = self._qualified(const_name)
                self.declared_entities.append({
                    "name": qualified,
                    "type": "constant",
                    "dtype": const_type
                })

        elif node_type == 'static_item':
            name_node = node.child_by_field_name('name')
            type_node = node.child_by_field_name('type')
            if name_node:
                static_name = self._get_node_text(name_node, code_bytes)
                static_type = self._get_node_text(type_node, code_bytes) if type_node else "unknown"
                qualified = self._qualified(static_name)
                self.declared_entities.append({
                    "name": qualified,
                    "type": "static",
                    "dtype": static_type
                })

        elif node_type == 'let_declaration':
            # BUG FIX: the binding name/type were previously computed and then
            # silently discarded; let-bindings are now recorded like other
            # variables.
            pattern = node.child_by_field_name('pattern')
            type_node = node.child_by_field_name('type')
            if pattern and pattern.type == 'identifier':
                var_name = self._get_node_text(pattern, code_bytes)
                var_type = self._get_node_text(type_node, code_bytes) if type_node else "unknown"
                self.declared_entities.append({
                    "name": self._qualified(var_name),
                    "type": "variable",
                    "dtype": var_type
                })

        elif node_type == 'use_declaration':
            # Record the whole `use ...;` statement as a dependency.
            use_text = self._get_node_text(node, code_bytes)
            self.called_entities.append(use_text)

        elif node_type == 'call_expression':
            function = node.child_by_field_name('function')
            if function:
                func_text = self._get_node_text(function, code_bytes)
                self.called_entities.append(func_text)

        elif node_type == 'macro_invocation':
            macro_node = node.child_by_field_name('macro')
            if macro_node:
                macro_name = self._get_node_text(macro_node, code_bytes)
                self.called_entities.append(f"{macro_name}!")

        # NOTE: a former 'field_expression' branch extracted the field name
        # and did nothing with it; that dead code was removed. Such nodes now
        # fall straight through to the generic child walk below.

        # Generic recursion for every node type not fully handled above.
        for child in node.children:
            self._walk_tree(child, code_bytes)

    def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
        """
        Extract entities from Rust code using tree-sitter.

        Args:
            code: Rust source code as a string.
            file_path: Unused; accepted for interface compatibility.

        Returns:
            Tuple of (declared_entities, called_entities), each de-duplicated
            in first-seen order.
        """
        self.reset()

        code_bytes = code.encode('utf8')
        tree = self.parser.parse(code_bytes)

        self._walk_tree(tree.root_node, code_bytes)

        # De-duplicate declarations while preserving first-seen order.
        seen_decl = set()
        unique_declared = []
        for e in self.declared_entities:
            key = (e.get("name"), e.get("type"), e.get("dtype", None))
            if key not in seen_decl:
                unique_declared.append(e)
                seen_decl.add(key)

        unique_called = list(dict.fromkeys(self.called_entities))

        return unique_declared, unique_called
|
|
|
|
|
|
|
|
class PythonASTEntityExtractor(ast.NodeVisitor, BaseASTEntityExtractor):
    """
    AST-based entity extractor for Python code.

    Also detects API endpoint definitions (FastAPI, Flask, Django REST
    Framework) declared through decorators.
    """

    # Decorator names that signal an HTTP endpoint definition.
    API_DECORATORS = {
        'route',
        'get', 'post', 'put', 'patch', 'delete', 'head', 'options',
        'api_view',
    }

    # Decorators whose own name is the HTTP method (FastAPI/Bottle style).
    _HTTP_METHOD_DECORATORS = {'get', 'post', 'put', 'patch', 'delete', 'head', 'options'}

    def __init__(self):
        # reset() establishes all per-run state; calling it here avoids
        # duplicating the attribute initialisation.
        self.reset()

    def reset(self) -> None:
        """Clear previous extraction state including visitor context."""
        self.declared_entities: List[Dict[str, Any]] = []
        self.called_entities: List[str] = []
        self.current_class: Optional[str] = None
        self.current_function: Optional[str] = None
        self.api_endpoints: List[Dict[str, Any]] = []

    def _get_type_annotation(self, node: ast.AST) -> str:
        """Render a type annotation AST node as a dotted/subscripted string."""
        if isinstance(node, ast.Name):
            return node.id
        elif isinstance(node, ast.Constant):
            return type(node.value).__name__
        elif isinstance(node, ast.Attribute):
            return f"{self._get_type_annotation(node.value)}.{node.attr}"
        elif isinstance(node, ast.Subscript):
            # e.g. List[int], Dict[str, int]
            base = self._get_type_annotation(node.value)
            if isinstance(node.slice, ast.Tuple):
                args = [self._get_type_annotation(elt) for elt in node.slice.elts]
                return f"{base}[{', '.join(args)}]"
            else:
                arg = self._get_type_annotation(node.slice)
                return f"{base}[{arg}]"
        return "unknown"

    def _infer_type_from_value(self, node: ast.AST) -> str:
        """Infer a type name from an assigned value expression."""
        if isinstance(node, ast.Constant):
            return type(node.value).__name__
        elif isinstance(node, ast.List):
            return "list"
        elif isinstance(node, ast.Dict):
            return "dict"
        elif isinstance(node, ast.Set):
            return "set"
        elif isinstance(node, ast.Tuple):
            return "tuple"
        elif isinstance(node, ast.Call):
            # `x = ClassName(...)` -- treat the callee name as the type.
            if isinstance(node.func, ast.Name):
                return node.func.id
            elif isinstance(node.func, ast.Attribute):
                return "unknown"
        elif isinstance(node, ast.Name):
            return "unknown"
        return "unknown"

    def visit_ClassDef(self, node: ast.ClassDef):
        """Visit class definitions: declare the class, record its bases."""
        old_class = self.current_class
        self.current_class = node.name

        self.declared_entities.append({
            "name": node.name,
            "type": "class"
        })

        # Base classes count as used ("called") entities.
        for base in node.bases:
            if isinstance(base, ast.Name):
                self.called_entities.append(base.id)
            elif isinstance(base, ast.Attribute):
                self.called_entities.append(self._get_type_annotation(base))

        self.generic_visit(node)
        self.current_class = old_class

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Visit function/method definitions and detect API endpoints."""
        old_function = self.current_function

        if self.current_class:
            # Methods are qualified as "<Class>.<method>".
            full_name = f"{self.current_class}.{node.name}"
            entity_type = "method"
        else:
            full_name = node.name
            entity_type = "function"

        self.current_function = full_name

        # Decorators may mark this function as an HTTP endpoint handler.
        api_info = self._extract_api_endpoint_from_decorators(node.decorator_list, full_name)
        if api_info:
            self.declared_entities.append({
                "name": full_name,
                "type": "api_endpoint",
                "endpoint": api_info.get("endpoint"),
                "methods": api_info.get("methods")
            })
            self.api_endpoints.append(api_info)
        else:
            self.declared_entities.append({
                "name": full_name,
                "type": entity_type
            })

        # Record parameters (positional-only/keyword-only args are not
        # covered here -- only node.args.args, matching prior behaviour).
        for arg in node.args.args:
            if arg.arg == 'self' and self.current_class:
                continue

            dtype = "unknown"
            if arg.annotation:
                dtype = self._get_type_annotation(arg.annotation)

            param_name = f"{full_name}.{arg.arg}" if entity_type == "method" else arg.arg
            self.declared_entities.append({
                "name": param_name,
                "type": "variable",
                "dtype": dtype
            })

        self.generic_visit(node)
        self.current_function = old_function

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
        """Visit async defs; identical handling to synchronous functions."""
        self.visit_FunctionDef(node)

    def visit_Assign(self, node: ast.Assign):
        """Visit assignment statements and record variable declarations."""
        dtype = self._infer_type_from_value(node.value)

        for target in node.targets:
            if isinstance(target, ast.Name):
                var_name = target.id
                # Skip plain locals inside methods; record module/class-level
                # names and function locals.
                if self.current_class and self.current_function and self.current_function.startswith(self.current_class):
                    pass
                else:
                    self.declared_entities.append({
                        "name": var_name,
                        "type": "variable",
                        "dtype": dtype
                    })

            elif isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name):
                # self.attr = ... inside a class -> instance attribute.
                if target.value.id == 'self' and self.current_class:
                    attr_name = f"{self.current_class}.{target.attr}"
                    self.declared_entities.append({
                        "name": attr_name,
                        "type": "variable",
                        "dtype": dtype
                    })

        self.generic_visit(node)

    def visit_AnnAssign(self, node: ast.AnnAssign):
        """Visit annotated assignment statements (PEP 526)."""
        if isinstance(node.target, ast.Name):
            dtype = self._get_type_annotation(node.annotation)
            var_name = node.target.id

            self.declared_entities.append({
                "name": var_name,
                "type": "variable",
                "dtype": dtype
            })

        elif isinstance(node.target, ast.Attribute) and isinstance(node.target.value, ast.Name):
            # self.attr: T = ... inside a class -> instance attribute.
            if node.target.value.id == 'self' and self.current_class:
                dtype = self._get_type_annotation(node.annotation)
                attr_name = f"{self.current_class}.{node.target.attr}"
                self.declared_entities.append({
                    "name": attr_name,
                    "type": "variable",
                    "dtype": dtype
                })

        # Only recurse when there is a value expression to inspect.
        if node.value:
            self.generic_visit(node)

    def visit_Import(self, node: ast.Import):
        """Visit import statements; record imported modules as used."""
        for alias in node.names:
            self.called_entities.append(alias.name)
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom):
        """Visit from...import statements; record module and member names."""
        if node.module:
            self.called_entities.append(node.module)

            for alias in node.names:
                if alias.name != '*':
                    self.called_entities.append(f"{node.module}.{alias.name}")
        else:
            # Relative import with no module part (e.g. `from . import x`).
            for alias in node.names:
                if alias.name != '*':
                    self.called_entities.append(alias.name)
        self.generic_visit(node)

    def visit_Call(self, node: ast.Call):
        """Visit function/method calls and record the callee."""
        if isinstance(node.func, ast.Name):
            self.called_entities.append(node.func.id)

        elif isinstance(node.func, ast.Attribute):
            if isinstance(node.func.value, ast.Name):
                method_name = node.func.attr

                # Resolve obj.method() to Class.method when the receiver's
                # type was recorded from an earlier assignment.
                obj_name = node.func.value.id
                obj_class = self._find_variable_type(obj_name)
                if obj_class and obj_class != "unknown":
                    self.called_entities.append(f"{obj_class}.{method_name}")
                else:
                    self.called_entities.append(method_name)

            elif isinstance(node.func.value, ast.Attribute):
                # Nested attribute chains recorded as dotted path.
                full_name = self._get_type_annotation(node.func)
                self.called_entities.append(full_name)

        self.generic_visit(node)

    def _find_variable_type(self, var_name: str) -> str:
        """Find the recorded type of a variable from declared entities."""
        for entity in self.declared_entities:
            if entity["name"] == var_name and entity["type"] == "variable":
                return entity.get("dtype", "unknown")
        return "unknown"

    def _extract_api_endpoint_from_decorators(self, decorators: List[ast.expr], function_name: str) -> Optional[Dict[str, Any]]:
        """
        Extract API endpoint information from function decorators.

        Handles patterns like:
          - @app.route("/api/users", methods=["GET", "POST"])   (Flask)
          - @app.get("/api/users")                              (FastAPI)
          - @router.post("/api/users")                          (FastAPI router)
          - @api_view(['GET', 'POST'])                          (Django REST Framework)
        """
        for decorator in decorators:
            if not isinstance(decorator, ast.Call):
                continue

            # The callee may be an attribute (@app.get) or a bare name
            # (@api_view, Bottle's @route/@get).  BUG FIX: bare-name
            # decorators were previously ignored, so DRF's @api_view was
            # never detected.
            if isinstance(decorator.func, ast.Attribute):
                method_name = decorator.func.attr.lower()
            elif isinstance(decorator.func, ast.Name):
                method_name = decorator.func.id.lower()
            else:
                continue

            if method_name not in self.API_DECORATORS:
                continue

            endpoint = None
            http_methods = []

            # First positional string argument is the endpoint path.
            if (decorator.args
                    and isinstance(decorator.args[0], ast.Constant)
                    and isinstance(decorator.args[0].value, str)):
                endpoint = decorator.args[0].value

            if method_name in self._HTTP_METHOD_DECORATORS:
                # FastAPI-style @app.get(...): the name is the HTTP method.
                http_methods = [method_name.upper()]
            elif method_name == 'route':
                # Flask-style @app.route(..., methods=[...]); defaults to GET.
                for keyword in decorator.keywords:
                    if keyword.arg == 'methods':
                        if isinstance(keyword.value, ast.List):
                            http_methods = [
                                elt.value for elt in keyword.value.elts
                                if isinstance(elt, ast.Constant) and isinstance(elt.value, str)
                            ]
                if not http_methods:
                    http_methods = ['GET']
            elif method_name == 'api_view':
                # DRF-style @api_view(['GET', 'POST']); carries no path.
                if decorator.args and isinstance(decorator.args[0], ast.List):
                    http_methods = [
                        elt.value for elt in decorator.args[0].elts
                        if isinstance(elt, ast.Constant) and isinstance(elt.value, str)
                    ]

            # BUG FIX: @api_view has no endpoint path, so it is accepted on
            # HTTP methods alone (previously discarded by the endpoint guard).
            if endpoint or (method_name == 'api_view' and http_methods):
                return {
                    "function": function_name,
                    "endpoint": endpoint,
                    "methods": http_methods,
                    "type": "api_endpoint_definition"
                }

        return None

    def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
        """
        Extract entities from Python code using AST parsing.

        Args:
            code: Python source code as string.
            file_path: Optional path to the source file (for context).

        Returns:
            Tuple of (declared_entities, called_entities); ([], []) when the
            code cannot be parsed.
        """
        self.reset()

        try:
            tree = ast.parse(code)
            self.visit(tree)

            # De-duplicate declarations while preserving first-seen order.
            seen_declared = set()
            unique_declared = []
            for entity in self.declared_entities:
                key = (entity["name"], entity["type"], entity.get("dtype"))
                if key not in seen_declared:
                    unique_declared.append(entity)
                    seen_declared.add(key)

            unique_called = list(dict.fromkeys(self.called_entities))

            return unique_declared, unique_called

        except SyntaxError as e:
            logger.error(f"Syntax error in Python code: {e}")
            return [], []
        except Exception as e:
            logger.error(f"Error parsing Python code: {e}", exc_info=True)
            return [], []
|
|
|
|
|
|
|
|
class HybridEntityExtractor:
    """
    Hybrid entity extractor that uses AST-based extraction for known
    languages and raises for unsupported ones so the caller can fall back
    to LLM-based extraction.
    """

    def __init__(self):
        # Map file extension -> extractor instance.  Instances are reused
        # across calls; each call resets the extractor first.
        self.extractors = {
            'py': PythonASTEntityExtractor(),
            'c': CEntityExtractor(),
            'h': CppEntityExtractor(),
            'cpp': CppEntityExtractor(),
            'cc': CppEntityExtractor(),
            'cxx': CppEntityExtractor(),
            'hpp': CppEntityExtractor(),
            'hxx': CppEntityExtractor(),
            'hh': CppEntityExtractor(),
            'java': JavaEntityExtractor(),
            'js': JavaScriptEntityExtractor(),
            'jsx': JavaScriptEntityExtractor(),
            'ts': JavaScriptEntityExtractor(),
            'tsx': JavaScriptEntityExtractor(),
            'rs': RustEntityExtractor(),
            'html': HTMLEntityExtractor()
        }

    def _get_language_from_filename(self, file_name: str) -> str:
        """Return the lower-cased extension of *file_name* (the whole name
        when it contains no dot)."""
        return file_name.split('.')[-1].lower()

    def extract_entities(self, code: str, file_name: str):
        """
        Extract (declared_entities, called_entities) from *code*.

        Args:
            code: Source code text.
            file_name: Used to pick the language-specific extractor and to
                generate aliases for declared entities.

        Returns:
            Tuple of (declared_entities, called_entities); ([], []) when AST
            extraction fails.

        Raises:
            ValueError: if no AST extractor exists for the file's extension.
        """
        lang = self._get_language_from_filename(file_name)
        extractor = self.extractors.get(lang)

        if not extractor:
            # BUG FIX: previously raised a bare Exception whose message
            # ("Using LLM extraction ...") misleadingly described a failure
            # as an action.  ValueError is still caught by callers that
            # catch Exception.
            raise ValueError(f"No AST extractor available for unsupported file type: {file_name}")

        # Defensive reset: extract_entities() also resets, but an extractor
        # may hold state from a previous failed run.
        try:
            extractor.reset()
        except Exception:
            pass

        logger.info(f"Using AST extraction for {lang.upper()} file: {file_name}")
        try:
            # Prefer the file_path-aware signature; fall back for extractors
            # that only accept (code).
            try:
                declared_entities, called_entities = extractor.extract_entities(code, file_path=file_name)
            except TypeError:
                declared_entities, called_entities = extractor.extract_entities(code)

            # Attach path-based aliases to every declared entity.
            for entity in declared_entities:
                entity_name = entity.get('name', '')
                if entity_name:
                    aliases = generate_entity_aliases(entity_name, file_name)
                    entity['aliases'] = aliases
                    logger.debug(f"Generated aliases for entity '{entity_name}': {aliases}")

            return declared_entities, called_entities
        except Exception as e:
            logger.error(f"Error during AST extraction for file {file_name}: {e}", exc_info=True)
            return [], []
|
|
|