# Author: lailaelkoussy
# Commit 3ec78dd: "Add gradio_mcp_space and dependencies"
import ast
import os
import logging
import tempfile
from typing import List, Dict, Any, Tuple, Optional
from clang import cindex
import javalang
import javalang.tree as T
import esprima
from bs4 import BeautifulSoup
import tree_sitter_rust as ts_rust
from tree_sitter import Language, Parser
import re
from .utils.path_utils import generate_entity_aliases
# Every extractor in this module reports problems through this one named logger.
LOGGER_NAME = "AST_ENTITY_EXTRACTOR"
logger = logging.getLogger(name=LOGGER_NAME)
class BaseASTEntityExtractor:
    """
    Common interface shared by the language-specific AST entity extractors.

    Concrete subclasses parse a source string and report the entities it
    declares together with the names it references (calls, imports, base
    types, ...).
    """

    def extract_entities(
        self, code: str, file_path: Optional[str] = None
    ) -> Tuple[List[Dict[str, Any]], List[str]]:
        """
        Extract entities from source code.

        Args:
            code: Source code as string.
            file_path: Optional path to the source file (for better context
                and include resolution).

        Returns:
            Tuple of (declared_entities, called_entities).

        Raises:
            NotImplementedError: always; concrete extractors must override.
        """
        raise NotImplementedError

    # Reset contract so extractor instances can be reused safely.
    def reset(self) -> None:
        """
        Reset internal state so the extractor instance can be reused.

        Concrete extractors should override this to clear their buffers.

        Raises:
            NotImplementedError: always; concrete extractors must override.
        """
        raise NotImplementedError
class HTMLEntityExtractor(BaseASTEntityExtractor):
    """
    Hybrid HTML AST-based entity extractor.

    Responsibilities:
      - Parse HTML into a tree
      - Extract declared DOM entities (ids, names, classes)
      - Extract JavaScript calls from inline event handlers
      - Extract JS entities from <script> tags whose type is JavaScript
      - Integrate cleanly with the hybrid AST graph linker
    """

    EVENT_ATTR_PREFIX = "on"  # e.g., onclick, onsubmit, etc.

    # <script type="..."> values whose body is JavaScript. The empty string
    # stands for a missing/blank type attribute, which also means JavaScript.
    # Other types (application/json, application/ld+json, text/template, ...)
    # carry data or markup and would only produce JS parse errors.
    JS_SCRIPT_TYPES = frozenset({
        "",
        "text/javascript",
        "application/javascript",
        "application/x-javascript",
        "text/ecmascript",
        "application/ecmascript",
        "module",
    })

    def __init__(self):
        # Inline handlers and <script> bodies are delegated to the JS extractor.
        self.js_extractor = JavaScriptEntityExtractor()
        self.reset()

    # --------------------------------------
    # Core interface
    # --------------------------------------
    def reset(self) -> None:
        """Clear the per-document buffers."""
        self.declared_entities: List[Dict[str, Any]] = []
        self.called_entities: List[str] = []

    def extract_entities(
        self, code: str, file_path: Optional[str] = None
    ) -> Tuple[List[Dict[str, Any]], List[str]]:
        """
        Parse HTML and extract entities.

        Args:
            code: HTML source.
            file_path: Unused; accepted for interface compatibility.

        Returns:
            (declared_entities, called_entities), both deduplicated in
            first-seen order. Returns ([], []) when the HTML cannot be parsed.
        """
        self.reset()
        try:
            soup = BeautifulSoup(code, "html.parser")
        except Exception as e:
            logger.error("[HTMLEntityExtractor] Parsing error: %s", e)
            return [], []

        # --- DOM element declarations and inline event handlers ---
        for tag in soup.find_all(True):
            self._handle_tag_declaration(tag)
            self._handle_event_attributes(tag)

        # --- <script> tags (inline + external) ---
        for script in soup.find_all("script"):
            self._handle_script(script)

        # --- Deduplication ---
        self.declared_entities = self._deduplicate_dicts(self.declared_entities)
        self.called_entities = self._deduplicate_list(self.called_entities)
        return self.declared_entities, self.called_entities

    # --------------------------------------
    # Tag & attribute handlers
    # --------------------------------------
    def _handle_tag_declaration(self, tag) -> None:
        """Extract declared DOM elements (id, name, class)."""
        if tag.has_attr("id"):
            self.declared_entities.append({"name": tag["id"], "type": "element"})
        if tag.has_attr("name"):
            self.declared_entities.append({"name": tag["name"], "type": "element"})
        if tag.has_attr("class"):
            classes = tag["class"]
            # html.parser exposes "class" as a list; a plain str is handled too.
            if isinstance(classes, list):
                for c in classes:
                    self.declared_entities.append({"name": c, "type": "class"})
            elif isinstance(classes, str):
                self.declared_entities.append({"name": classes, "type": "class"})

    def _handle_event_attributes(self, tag) -> None:
        """Extract JS calls from inline event attributes (onclick=..., ...)."""
        if not self.js_extractor:
            return
        for attr, value in tag.attrs.items():
            if attr.lower().startswith(self.EVENT_ATTR_PREFIX) and isinstance(value, str):
                try:
                    _, called = self.js_extractor.extract_entities(value)
                    self.called_entities.extend(called)
                except Exception as e:
                    logger.warning("[HTMLEntityExtractor] JS parse error in %s: %s", attr, e)

    def _is_javascript(self, script) -> bool:
        """True when the <script> type attribute denotes JavaScript."""
        script_type = script.get("type") or ""
        if isinstance(script_type, list):
            script_type = " ".join(script_type)
        # Drop MIME parameters such as "; charset=utf-8".
        mime = str(script_type).split(";", 1)[0].strip().lower()
        return mime in self.JS_SCRIPT_TYPES

    def _handle_script(self, script) -> None:
        """Extract JS entities from <script> blocks, or record their src."""
        if script.has_attr("src"):
            self.called_entities.append(script["src"])
            return
        if not self.js_extractor or not self._is_javascript(script):
            return
        js_code = (script.string or "").strip()
        if js_code:
            try:
                declared, called = self.js_extractor.extract_entities(js_code)
                self.declared_entities.extend(declared)
                self.called_entities.extend(called)
            except Exception as e:
                logger.warning("[HTMLEntityExtractor] JS parse error in <script>: %s", e)

    # --------------------------------------
    # Helpers
    # --------------------------------------
    @staticmethod
    def _deduplicate_dicts(dicts: List[Dict]) -> List[Dict]:
        """Drop dicts whose (key, value) pairs repeat an earlier one; keep order."""
        seen = set()
        result = []
        for d in dicts:
            key = tuple(sorted(d.items()))
            if key not in seen:
                seen.add(key)
                result.append(d)
        return result

    @staticmethod
    def _deduplicate_list(items: List[str]) -> List[str]:
        """Remove duplicates while preserving first-seen order."""
        return list(dict.fromkeys(items))
class JavaEntityExtractor(BaseASTEntityExtractor):
    """
    Extract declared and called entities from Java code using javalang.

    Produces the same (declared_entities, called_entities) structure as the
    other extractors. Methods carrying Spring MVC mapping annotations
    (@GetMapping, @RequestMapping, ...) are reported with type
    "api_endpoint" and are also collected in ``self.api_endpoints``.
    """

    # Spring shortcut annotations whose name encodes the HTTP verb.
    _VERB_MAPPINGS = frozenset(
        {'GetMapping', 'PostMapping', 'PutMapping', 'PatchMapping', 'DeleteMapping'}
    )

    def __init__(self):
        self.reset()

    def reset(self) -> None:
        """Clear all per-file state."""
        self.declared_entities: List[Dict[str, Any]] = []
        self.called_entities: List[str] = []
        self.current_package: Optional[str] = None
        # Enclosing type names, innermost last; joined with "::" by _qualified().
        self.scope_stack: List[str] = []
        self.api_endpoints: List[Dict[str, Any]] = []  # API endpoint definitions
        self.current_class_base_path: Optional[str] = None  # class-level @RequestMapping

    # -----------------------------------------------------------
    # Helpers
    # -----------------------------------------------------------
    def _qualified(self, name: str) -> str:
        """Prefix ``name`` with the enclosing type scopes, joined by "::"."""
        if not name:
            return ""
        scope = "::".join(self.scope_stack)
        return f"{scope}::{name}" if scope else name

    def _type_name(self, node) -> str:
        """
        Scope entry for a type declaration.

        Only top-level types carry the package prefix. Nested types are
        already qualified by their enclosing type on the scope stack, so
        re-adding the package would yield ``pkg.Outer::pkg.Inner``.
        """
        if self.current_package and not self.scope_stack:
            return f"{self.current_package}.{node.name}"
        return node.name

    @staticmethod
    def _annotation_name(annotation) -> str:
        """Simple annotation name, so @org.springframework...GetMapping matches GetMapping."""
        return (annotation.name or "").rsplit('.', 1)[-1]

    def _walk_type(self, t):
        """Return string representation of a type node (generic arguments included)."""
        if not t:
            return "unknown"
        if isinstance(t, str):
            return t
        if hasattr(t, "name"):
            name = t.name
            if getattr(t, "arguments", None):
                args = [self._walk_type(a.type) for a in t.arguments if hasattr(a, "type")]
                name += "<" + ", ".join(args) + ">"
            return name
        return "unknown"

    def _declare_params(self, owner: str, params) -> None:
        """Declare each formal parameter as a variable named ``owner.param``."""
        for param in params or []:
            self.declared_entities.append({
                "name": f"{owner}.{param.name}",
                "type": "variable",
                "dtype": self._walk_type(param.type),
            })

    # -----------------------------------------------------------
    # Main AST traversal
    # -----------------------------------------------------------
    def extract_entities(
        self, code: str, file_path: Optional[str] = None
    ) -> Tuple[List[Dict[str, Any]], List[str]]:
        """
        Parse Java source and return (declared_entities, called_entities).

        Returns ([], []) and logs an error when javalang cannot parse the code.
        Declared entities are deduplicated on (name, type, dtype); called
        entities keep first-seen order.
        """
        self.reset()
        try:
            tree = javalang.parse.parse(code)
        except javalang.parser.JavaSyntaxError as e:
            logger.error(f"Syntax error in Java code: {e}")
            return [], []
        except Exception as e:
            logger.error(f"Error parsing Java code: {e}", exc_info=True)
            return [], []

        # --- Package ---
        if tree.package:
            self.current_package = tree.package.name

        # --- Imports ---
        for imp in tree.imports:
            self.called_entities.append(imp.path)

        # --- Types (classes, interfaces, enums) ---
        for type_decl in tree.types:
            self._visit_type(type_decl)

        # Deduplicate
        seen_decl = set()
        unique_declared = []
        for e in self.declared_entities:
            key = (e.get("name"), e.get("type"), e.get("dtype"))
            if key not in seen_decl:
                unique_declared.append(e)
                seen_decl.add(key)
        unique_called = list(dict.fromkeys(self.called_entities))
        return unique_declared, unique_called

    # -----------------------------------------------------------
    # Visitors for different node types
    # -----------------------------------------------------------
    def _visit_type(self, node):
        """Dispatch a type declaration to its visitor."""
        if isinstance(node, javalang.tree.ClassDeclaration):
            self._visit_class(node)
        elif isinstance(node, javalang.tree.InterfaceDeclaration):
            self._visit_interface(node)
        elif isinstance(node, javalang.tree.EnumDeclaration):
            self._visit_enum(node)

    def _visit_class(self, node):
        """Declare a class, record its supertypes and visit its members."""
        full_name = self._type_name(node)
        self.declared_entities.append({"name": self._qualified(full_name), "type": "class"})

        # A class-level @RequestMapping supplies the base path for its handlers.
        old_base_path = self.current_class_base_path
        for annotation in node.annotations or []:
            if self._annotation_name(annotation) == 'RequestMapping':
                self.current_class_base_path = self._extract_path_from_annotation(annotation)

        # Inheritance
        if node.extends:
            self.called_entities.append(self._walk_type(node.extends))
        for impl in node.implements or []:
            self.called_entities.append(self._walk_type(impl))

        self.scope_stack.append(full_name)
        for member in node.body:
            self._visit_member(member)
        self.scope_stack.pop()

        # Restore the enclosing class's base path.
        self.current_class_base_path = old_base_path

    def _visit_interface(self, node):
        """Declare an interface, record extended interfaces and visit members."""
        full_name = self._type_name(node)
        self.declared_entities.append({"name": self._qualified(full_name), "type": "interface"})
        for parent in node.extends or []:
            self.called_entities.append(self._walk_type(parent))
        self.scope_stack.append(full_name)
        for member in node.body:
            self._visit_member(member)
        self.scope_stack.pop()

    def _visit_enum(self, node):
        """Declare an enum, its constants, and visit its body declarations."""
        full_name = self._type_name(node)
        self.declared_entities.append({"name": self._qualified(full_name), "type": "enum"})
        for impl in node.implements or []:
            self.called_entities.append(self._walk_type(impl))

        # javalang stores enum contents in an EnumBody (constants + declarations).
        body = getattr(node, "body", None)
        if body is None:
            return
        self.scope_stack.append(full_name)
        for const in getattr(body, "constants", None) or []:
            self.declared_entities.append({
                "name": self._qualified(const.name),
                "type": "variable",
                "dtype": node.name,
            })
            # Constructor arguments and constant-specific class bodies may call things.
            self._find_calls(const.arguments)
            self._find_calls(const.body)
        for member in getattr(body, "declarations", None) or []:
            self._visit_member(member)
        self.scope_stack.pop()

    def _visit_member(self, node):
        """Handle a single class/interface/enum body member."""
        # --- Method ---
        if isinstance(node, T.MethodDeclaration):
            method_name = self._qualified(node.name)
            api_info = self._extract_api_endpoint_from_annotations(node)
            if api_info:
                self.declared_entities.append({
                    "name": method_name,
                    "type": "api_endpoint",
                    "endpoint": api_info.get("endpoint"),
                    "methods": api_info.get("methods"),
                })
                self.api_endpoints.append({**api_info, "function": method_name})
            else:
                self.declared_entities.append({"name": method_name, "type": "method"})
            self._declare_params(method_name, node.parameters)
            if node.body:
                self._find_calls(node.body)

        # --- Constructor ---
        elif isinstance(node, T.ConstructorDeclaration):
            ctor_name = self._qualified(node.name)
            self.declared_entities.append({"name": ctor_name, "type": "constructor"})
            self._declare_params(ctor_name, node.parameters)
            if node.body:
                self._find_calls(node.body)

        # --- Field ---
        elif isinstance(node, T.FieldDeclaration):
            dtype = self._walk_type(node.type)
            for decl in node.declarators:
                self.declared_entities.append({
                    "name": self._qualified(decl.name),
                    "type": "variable",
                    "dtype": dtype,
                })
                # Initializers such as `= new Foo()` or `= build()` contain calls.
                self._find_calls(decl.initializer)

        # --- Nested class/interface/enum ---
        elif isinstance(node, (T.ClassDeclaration, T.InterfaceDeclaration, T.EnumDeclaration)):
            self._visit_type(node)

    # -----------------------------------------------------------
    # API Endpoint Detection
    # -----------------------------------------------------------
    def _extract_api_endpoint_from_annotations(self, method) -> Optional[Dict[str, Any]]:
        """
        Extract API endpoint information from Spring method annotations.

        Handles @GetMapping/@PostMapping/@PutMapping/@PatchMapping/
        @DeleteMapping and @RequestMapping. A mapping without an explicit
        path maps to the class-level base path (or "/" when there is none).
        """
        if not method.annotations:
            return None
        for annotation in method.annotations:
            name = self._annotation_name(annotation)
            if name in self._VERB_MAPPINGS:
                # "GetMapping" -> "GET"
                http_methods = [name[:-len('Mapping')].upper()]
            elif name == 'RequestMapping':
                http_methods = self._extract_methods_from_annotation(annotation) or ['GET']  # Default to GET
            else:
                continue
            path = self._extract_path_from_annotation(annotation) or ""
            full_path = self._combine_paths(self.current_class_base_path, path) or "/"
            return {
                "endpoint": full_path,
                "methods": http_methods,
                "type": "api_endpoint_definition",
            }
        return None

    @staticmethod
    def _first_string(value) -> Optional[str]:
        """String of a Literal, or of the first Literal in an array value."""
        if isinstance(value, T.Literal):
            return value.value.strip('"')
        if isinstance(value, T.ElementArrayValue) and value.values:
            first_val = value.values[0]
            if isinstance(first_val, T.Literal):
                return first_val.value.strip('"')
        return None

    def _extract_path_from_annotation(self, annotation) -> Optional[str]:
        """
        Extract the path from a Spring mapping annotation.

        Supports @X("/p"), @X({"/p", ...}), @X(value = "/p"), @X(path = "/p")
        and the array forms of the named elements; the first path wins.
        """
        element = annotation.element
        if not element:
            return None
        if isinstance(element, list):
            for elem in element:
                if isinstance(elem, T.ElementValuePair) and elem.name in {'value', 'path'}:
                    found = self._first_string(elem.value)
                    if found is not None:
                        return found
            return None
        return self._first_string(element)

    def _extract_methods_from_annotation(self, annotation) -> List[str]:
        """Extract HTTP methods from a @RequestMapping(method = ...) element."""
        methods = []
        if isinstance(annotation.element, list):
            for elem in annotation.element:
                if isinstance(elem, T.ElementValuePair) and elem.name == 'method':
                    if hasattr(elem.value, 'member'):
                        # Single method: RequestMethod.GET
                        methods.append(elem.value.member)
                    elif isinstance(elem.value, T.ElementArrayValue):
                        # Multiple methods: {RequestMethod.GET, RequestMethod.POST}
                        for val in elem.value.values:
                            if hasattr(val, 'member'):
                                methods.append(val.member)
        return methods

    def _combine_paths(self, base_path: Optional[str], path: str) -> str:
        """Combine the class-level base path with a method path using one '/'."""
        if not base_path:
            return path
        base = base_path.rstrip('/')
        path = path.lstrip('/')
        return f"{base}/{path}" if path else base

    # -----------------------------------------------------------
    # Find method invocations
    # -----------------------------------------------------------
    def _find_calls(self, statements):
        """
        Record method invocations and constructor calls under javalang nodes.

        Walks the tree iteratively in pre-order (same order as a recursive
        walk) so deeply nested expressions cannot hit Python's recursion limit.
        """
        if statements is None:
            return
        if isinstance(statements, list):
            stack = list(reversed(statements))
        else:
            stack = [statements]
        while stack:
            node = stack.pop()
            if isinstance(node, T.MethodInvocation):
                if node.qualifier:
                    self.called_entities.append(f"{node.qualifier}.{node.member}")
                else:
                    self.called_entities.append(node.member)
            elif isinstance(node, T.ClassCreator):
                self.called_entities.append(self._walk_type(node.type))

            if not hasattr(node, '__dict__'):
                continue
            children = []
            for val in vars(node).values():
                if isinstance(val, list):
                    children.extend(child for child in val if isinstance(child, T.Node))
                elif isinstance(val, T.Node):
                    children.append(val)
            # Push in reverse so children are processed in declaration order.
            stack.extend(reversed(children))
class JavaScriptEntityExtractor(BaseASTEntityExtractor):
    """
    Extract declared and called entities from JavaScript code using esprima.

    Handles ES6+ syntax including classes, arrow functions, destructuring and
    imports/exports. Also detects API endpoint calls (fetch, axios, etc.).
    """

    # Common HTTP methods to detect
    HTTP_METHODS = {'get', 'post', 'put', 'patch', 'delete', 'head', 'options'}

    # API call patterns to detect
    API_PATTERNS = {
        'fetch',       # fetch('/api/users')
        'axios',       # axios.get('/api/users')
        '$http',       # Angular $http
        'request',     # request library
        'superagent',  # superagent library
    }

    def __init__(self):
        self.reset()

    def reset(self) -> None:
        """Clear all per-file state."""
        self.declared_entities: List[Dict[str, Any]] = []
        self.called_entities: List[str] = []
        # Enclosing function/class/method names, innermost last.
        self.scope_stack: List[str] = []
        self.api_calls: List[Dict[str, Any]] = []  # API endpoint calls

    def _qualified(self, name: str) -> str:
        """Return fully qualified name using current scope stack."""
        if not name:
            return ""
        scope = ".".join(self.scope_stack)
        return f"{scope}.{name}" if scope else name

    def _get_function_name(self, node) -> Optional[str]:
        """Extract function name from various function node types."""
        if hasattr(node, 'id') and node.id:
            return node.id.name
        return None

    # -----------------------------------------------------------
    # Function helpers
    # -----------------------------------------------------------
    def _declare_params(self, owner: str, params) -> None:
        """Declare every name bound by ``params`` as ``owner.param``."""
        for param in params or []:
            for name in self._extract_pattern_names(param):
                self.declared_entities.append({
                    "name": f"{owner}.{name}",
                    "type": "variable",
                    "dtype": "unknown",
                })

    def _walk_function(self, node, qualified: Optional[str], scope_name: Optional[str]) -> None:
        """
        Declare a function's params (when it has a qualified name) and walk
        its params and body, inside ``scope_name`` when one is given.
        """
        params = getattr(node, 'params', None) or []
        if qualified:
            self._declare_params(qualified, params)
        if scope_name:
            self.scope_stack.append(scope_name)
        # Default values (AssignmentPattern.right) can contain calls.
        for param in params:
            self._walk_node(param)
        self._walk_node(getattr(node, 'body', None))
        if scope_name:
            self.scope_stack.pop()

    def _visit_function(self, node) -> None:
        """Handle FunctionDeclaration / FunctionExpression, named or anonymous."""
        func_name = self._get_function_name(node)
        qualified = None
        if func_name:
            # Qualify before pushing the function's own scope.
            qualified = self._qualified(func_name)
            self.declared_entities.append({"name": qualified, "type": "function"})
        self._walk_function(node, qualified, func_name)

    def _visit_class(self, node) -> None:
        """Declare a class (if named), record its superclass, walk its members."""
        class_id = getattr(node, 'id', None)
        class_name = getattr(class_id, 'name', None) if class_id else None
        if class_name:
            self.declared_entities.append({"name": self._qualified(class_name), "type": "class"})

        super_class = getattr(node, 'superClass', None)
        if super_class:
            # Identifier -> "Base", MemberExpression -> "React.Component".
            super_name = self._extract_callee_name(super_class)
            if super_name:
                self.called_entities.append(super_name)
            # e.g. `extends mixin(Base)` holds a call.
            self._walk_node(super_class)

        if class_name:
            self.scope_stack.append(class_name)
        body = getattr(node, 'body', None)
        for member in getattr(body, 'body', None) or []:
            self._walk_node(member)
        if class_name:
            self.scope_stack.pop()

    # -----------------------------------------------------------
    # Main walker
    # -----------------------------------------------------------
    def _walk_node(self, node):
        """Recursively walk the AST and extract entities."""
        if not node or not hasattr(node, 'type'):
            return
        node_type = node.type

        # --- Function Declaration / Expression ---
        if node_type in ('FunctionDeclaration', 'FunctionExpression'):
            self._visit_function(node)

        # --- Arrow Function Expression ---
        elif node_type == 'ArrowFunctionExpression':
            # Arrow functions are anonymous; a VariableDeclarator that binds
            # one declares it and pushes its scope before walking it.
            self._walk_function(node, None, None)

        # --- Class Declaration ---
        elif node_type == 'ClassDeclaration':
            self._visit_class(node)

        # --- Method Definition ---
        elif node_type == 'MethodDefinition':
            key = getattr(node, 'key', None)
            method_name = getattr(key, 'name', None) if key else None
            value = getattr(node, 'value', None)
            if method_name:
                qualified = self._qualified(method_name)
                self.declared_entities.append({"name": qualified, "type": "method"})
                self._declare_params(qualified, getattr(value, 'params', None))
                # Scope the body under the method, like named functions.
                self.scope_stack.append(method_name)
                self._walk_node(value)
                self.scope_stack.pop()
            else:
                self._walk_node(value)

        # --- Variable Declaration ---
        elif node_type == 'VariableDeclaration':
            for decl in getattr(node, 'declarations', None) or []:
                self._walk_node(decl)

        # --- Variable Declarator ---
        elif node_type == 'VariableDeclarator':
            self._visit_declarator(node)

        # --- Call Expression ---
        elif node_type == 'CallExpression':
            callee = getattr(node, 'callee', None)
            callee_name = self._extract_callee_name(callee)
            if callee_name:
                self.called_entities.append(callee_name)
                self._detect_api_call(node, callee_name)
            # The callee itself may hold calls or function bodies:
            # fetch(url).then(...), axios.get(...).then(...), IIFEs.
            self._walk_node(callee)
            for arg in getattr(node, 'arguments', None) or []:
                self._walk_node(arg)

        # --- New Expression ---
        elif node_type == 'NewExpression':
            callee = getattr(node, 'callee', None)
            callee_name = self._extract_callee_name(callee)
            if callee_name:
                # Constructor calls count as references, like Java's ClassCreator.
                self.called_entities.append(callee_name)
            self._walk_node(callee)
            for arg in getattr(node, 'arguments', None) or []:
                self._walk_node(arg)

        # --- Member Expression ---
        elif node_type == 'MemberExpression':
            # Not a call by itself; traverse both sides.
            self._walk_node(getattr(node, 'object', None))
            self._walk_node(getattr(node, 'property', None))

        # --- Import/Export ---
        elif node_type == 'ImportDeclaration':
            source = getattr(node, 'source', None)
            if source is not None and hasattr(source, 'value'):
                self.called_entities.append(source.value)
        elif node_type in ('ExportNamedDeclaration', 'ExportDefaultDeclaration', 'ExportAllDeclaration'):
            # Re-exports (`export {a} from './m'`, `export * from './m'`) reference a module.
            source = getattr(node, 'source', None)
            source_value = getattr(source, 'value', None) if source else None
            if isinstance(source_value, str):
                self.called_entities.append(source_value)
            self._walk_node(getattr(node, 'declaration', None))

        # --- Recursive traversal for other nodes ---
        else:
            if hasattr(node, '__dict__'):
                for attr, val in vars(node).items():
                    if isinstance(val, list):
                        for item in val:
                            if hasattr(item, 'type'):
                                self._walk_node(item)
                    elif hasattr(val, 'type'):
                        self._walk_node(val)

    def _visit_declarator(self, node) -> None:
        """Declare the name(s) bound by a VariableDeclarator and walk its init."""
        target = getattr(node, 'id', None)
        init = getattr(node, 'init', None)
        var_name = self._extract_pattern_name(target)
        if var_name:
            qualified = self._qualified(var_name)
            if init and getattr(init, 'type', None) in ('FunctionExpression', 'ArrowFunctionExpression'):
                self.declared_entities.append({"name": qualified, "type": "function"})
                # Anonymous functions take their params' owner from the variable.
                if not getattr(init, 'id', None):
                    self._declare_params(qualified, getattr(init, 'params', None))
                self.scope_stack.append(var_name)
                self._walk_node(init)
                self.scope_stack.pop()
            else:
                self.declared_entities.append({
                    "name": qualified,
                    "type": "variable",
                    "dtype": "unknown",
                })
                self._walk_node(init)
            return

        # Destructuring (`const {a, b} = ...`, `const [x, y] = ...`).
        for name in self._extract_pattern_names(target):
            self.declared_entities.append({
                "name": self._qualified(name),
                "type": "variable",
                "dtype": "unknown",
            })
        self._walk_node(target)  # default values inside the pattern
        self._walk_node(init)

    # -----------------------------------------------------------
    # Name helpers
    # -----------------------------------------------------------
    def _extract_pattern_name(self, pattern) -> Optional[str]:
        """Name of a single-binding pattern (Identifier, RestElement, AssignmentPattern)."""
        ptype = getattr(pattern, 'type', None) if pattern else None
        if ptype == 'Identifier':
            return getattr(pattern, 'name', None)
        if ptype == 'RestElement':
            return self._extract_pattern_name(getattr(pattern, 'argument', None))
        if ptype == 'AssignmentPattern':
            # Default parameter / default value: `a = 1`.
            return self._extract_pattern_name(getattr(pattern, 'left', None))
        return None

    def _extract_pattern_names(self, pattern) -> List[str]:
        """All identifiers bound by a (possibly destructuring) pattern, in order."""
        ptype = getattr(pattern, 'type', None) if pattern else None
        if ptype == 'Identifier':
            name = getattr(pattern, 'name', None)
            return [name] if name else []
        if ptype == 'RestElement':
            return self._extract_pattern_names(getattr(pattern, 'argument', None))
        if ptype == 'AssignmentPattern':
            return self._extract_pattern_names(getattr(pattern, 'left', None))
        if ptype == 'ArrayPattern':
            names: List[str] = []
            for element in getattr(pattern, 'elements', None) or []:
                names.extend(self._extract_pattern_names(element))
            return names
        if ptype == 'ObjectPattern':
            names = []
            for prop in getattr(pattern, 'properties', None) or []:
                # Property -> bound pattern in .value; RestElement is itself the pattern.
                bound = getattr(prop, 'value', None) if getattr(prop, 'type', None) == 'Property' else prop
                names.extend(self._extract_pattern_names(bound))
            return names
        return []

    def _extract_callee_name(self, callee) -> Optional[str]:
        """Dotted name of the function being called (e.g. "axios.get"), if any."""
        ctype = getattr(callee, 'type', None) if callee else None
        if ctype == 'Identifier':
            return getattr(callee, 'name', None)
        if ctype == 'MemberExpression':
            obj = self._extract_callee_name(getattr(callee, 'object', None))
            prop = getattr(getattr(callee, 'property', None), 'name', None) or ""
            if obj and prop:
                return f"{obj}.{prop}"
            return prop or obj
        return None

    # -----------------------------------------------------------
    # API call detection
    # -----------------------------------------------------------
    def _object_string_property(self, node, key: str) -> Optional[str]:
        """String value of the non-computed property ``key`` of an ObjectExpression."""
        if getattr(node, 'type', None) != 'ObjectExpression':
            return None
        for prop in getattr(node, 'properties', None) or []:
            if getattr(prop, 'type', None) != 'Property' or getattr(prop, 'computed', False):
                continue
            prop_key = getattr(prop, 'key', None)
            key_type = getattr(prop_key, 'type', None)
            if key_type == 'Identifier':
                prop_name = getattr(prop_key, 'name', None)
            elif key_type == 'Literal':
                prop_name = getattr(prop_key, 'value', None)
            else:
                prop_name = None
            if prop_name == key:
                return self._extract_string_literal(getattr(prop, 'value', None))
        return None

    def _detect_api_call(self, call_node, callee_name: str):
        """
        Detect API endpoint calls in JavaScript code.

        Handles patterns like:
          - fetch('/api/users') and fetch('/api/users', {method: 'POST'})
          - axios.get('/api/users'), axios.post('/api/users', data)
          - axios('/api/users', {method: 'post'}), axios({url: '/api', method: 'post'})
          - request.get('/api/users')
        Records "API:<METHOD>:<endpoint>" in called_entities and in api_calls.
        """
        args = getattr(call_node, 'arguments', None)
        if not callee_name or not args:
            return

        parts = callee_name.split('.')
        base = parts[0]
        verb = parts[-1].lower() if len(parts) > 1 else None

        if base == 'fetch':
            http_method = 'GET'  # Default for fetch
        elif base in self.API_PATTERNS and verb in self.HTTP_METHODS:
            http_method = verb.upper()
        elif base in self.API_PATTERNS and verb is None:
            http_method = 'GET'  # Default
        else:
            return

        first_arg = args[0]
        options = args[1] if len(args) > 1 else None
        endpoint = self._extract_string_literal(first_arg)
        if endpoint is None:
            # Config-object form: axios({url: '/api', method: 'post'}).
            endpoint = self._object_string_property(first_arg, 'url')
            options = first_arg
        if not endpoint:
            return

        # An options `method` overrides the default verb, but never an explicit
        # one such as axios.get(...).
        if verb not in self.HTTP_METHODS:
            explicit = self._object_string_property(options, 'method')
            if explicit:
                http_method = explicit.upper()

        self.called_entities.append(f"API:{http_method}:{endpoint}")
        self.api_calls.append({
            "endpoint": endpoint,
            "method": http_method,
            "type": "api_call",
        })

    def _extract_string_literal(self, node) -> Optional[str]:
        """
        String value of a Literal or TemplateLiteral node.

        Template substitutions become "{param}", e.g.
        `/api/${version}/users` -> /api/{param}/users.
        """
        if not node or not hasattr(node, 'type'):
            return None
        if node.type == 'Literal' and isinstance(getattr(node, 'value', None), str):
            return node.value
        if node.type == 'TemplateLiteral':
            quasis = getattr(node, 'quasis', None)
            if quasis is None:
                return None
            parts = []
            for i, quasi in enumerate(quasis):
                value = getattr(quasi, 'value', None)
                raw = value.get('raw') if isinstance(value, dict) else getattr(value, 'raw', None)
                if raw is not None:
                    parts.append(raw)
                if i < len(quasis) - 1:
                    parts.append('{param}')
            return ''.join(parts)
        return None

    def extract_entities(
        self, code: str, file_path: Optional[str] = None
    ) -> Tuple[List[Dict[str, Any]], List[str]]:
        """
        Parse JavaScript (script first, then module) and return
        (declared_entities, called_entities). Returns ([], []) and logs an
        error when neither parse succeeds.
        """
        self.reset()
        try:
            tree = esprima.parseScript(code, {'tolerant': True, 'loc': False})
        except Exception:
            # Try parsing as module if script fails
            try:
                tree = esprima.parseModule(code, {'tolerant': True, 'loc': False})
            except Exception as e2:
                logger.error(f"Failed to parse JavaScript code: {e2}")
                return [], []

        for node in getattr(tree, 'body', None) or []:
            self._walk_node(node)

        # Deduplicate
        seen_decl = set()
        unique_declared = []
        for e in self.declared_entities:
            key = (e.get("name"), e.get("type"), e.get("dtype"))
            if key not in seen_decl:
                unique_declared.append(e)
                seen_decl.add(key)
        unique_called = list(dict.fromkeys(self.called_entities))
        return unique_declared, unique_called
class CEntityExtractor(BaseASTEntityExtractor):
    """
    Extract declared and called entities from C code using clang.cindex
    (libclang). Nodes located in other files (system and project headers)
    are ignored; #include directives are reported as called entities.
    """

    # Builtin type words that never name a project entity.
    _PRIMITIVE_TYPES = frozenset({
        'int', 'float', 'double', 'char', 'bool', 'void',
        'long', 'short', 'unsigned', 'signed', 'size_t',
    })
    # Qualifiers and tag keywords, matched as whole words so identifiers such
    # as ``constraint_t`` or ``my_struct`` are left intact.
    _TYPE_KEYWORDS_RE = re.compile(r'\b(?:const|volatile|restrict|struct|union|enum)\b')

    def __init__(self):
        self.index = cindex.Index.create()

    def reset(self) -> None:
        """No persistent state to reset, but method provided for interface consistency."""

    @classmethod
    def _referenced_type_name(cls, spelling: str) -> Optional[str]:
        """
        Base type name referenced by a variable's type spelling, or None.

        Strips array extents, pointers/references, qualifiers and tag
        keywords ("const struct Point *" -> "Point"). Returns None for
        builtin types (including multi-word ones like "unsigned int") and for
        function-pointer spellings, which do not name a single entity.
        """
        name = re.sub(r'\[[^\]]*\]', ' ', spelling)
        name = name.replace('*', ' ').replace('&', ' ')
        name = cls._TYPE_KEYWORDS_RE.sub(' ', name)
        name = ' '.join(name.split())
        if not name or '(' in name:
            return None
        if all(word in cls._PRIMITIVE_TYPES for word in name.split()):
            return None
        return name

    def _walk_cursor(self, cursor, declared, called, source_file):
        """Recursively walk a clang Cursor, restricted to the main file."""
        target = os.path.abspath(source_file) if source_file else None
        for c in cursor.get_children():
            kind = c.kind

            # --- Include directives ---
            # INCLUSION_DIRECTIVE nodes appear at the root level (they require
            # PARSE_DETAILED_PROCESSING_RECORD) and are not location-filtered.
            if kind == cindex.CursorKind.INCLUSION_DIRECTIVE:
                if c.displayname:
                    called.append(c.displayname)
                continue

            # Skip nodes without a file and nodes from system / external headers.
            loc = c.location
            if not loc.file or target is None:
                continue
            if os.path.abspath(loc.file.name) != target:
                continue

            # --- Declarations ---
            if kind.is_declaration():
                if kind in (cindex.CursorKind.FUNCTION_DECL, cindex.CursorKind.FUNCTION_TEMPLATE):
                    name = c.spelling or c.displayname
                    declared.append({"name": name, "type": "function"})
                    for p in c.get_arguments():
                        # Prototypes may omit parameter names (`void f(int);`).
                        if p.spelling:
                            declared.append({
                                "name": f"{name}.{p.spelling}",
                                "type": "variable",
                                "dtype": p.type.spelling,
                            })
                elif kind == cindex.CursorKind.VAR_DECL:
                    declared.append({
                        "name": c.spelling,
                        "type": "variable",
                        "dtype": c.type.spelling,
                    })
                    # The variable's type is a reference, e.g. `struct Point p;`.
                    if c.type.spelling:
                        type_name = self._referenced_type_name(c.type.spelling)
                        if type_name:
                            called.append(type_name)
                elif kind == cindex.CursorKind.STRUCT_DECL:
                    declared.append({"name": c.spelling or c.displayname, "type": "struct"})
                elif kind == cindex.CursorKind.TYPEDEF_DECL:
                    declared.append({"name": c.spelling, "type": "typedef"})

            # --- Calls ---
            if kind == cindex.CursorKind.CALL_EXPR:
                callee = None
                for child in c.get_children():
                    if child.kind in (cindex.CursorKind.DECL_REF_EXPR, cindex.CursorKind.MEMBER_REF_EXPR):
                        callee = child.spelling
                        break
                called.append(callee if callee else (c.displayname or c.spelling))

            # --- Recurse ---
            self._walk_cursor(c, declared, called, source_file)

    def extract_entities(
        self, code: str, file_path: Optional[str] = None
    ) -> Tuple[List[Dict[str, Any]], List[str]]:
        """
        Parse C code with libclang and return (declared_entities, called_entities).

        If ``file_path`` exists it is parsed directly (better include
        resolution); otherwise ``code`` is written to a temporary .c file
        that is always removed afterwards, even when parsing fails.

        Raises:
            RuntimeError: if libclang cannot parse the translation unit.
        """
        declared, called = [], []

        temp_file = not (file_path and os.path.exists(file_path))
        if temp_file:
            # libclang reads sources as UTF-8; don't depend on the locale encoding.
            with tempfile.NamedTemporaryFile(
                suffix=".c", mode="w", encoding="utf-8", delete=False
            ) as tf:
                tf.write(code)
                tf_name = tf.name
        else:
            tf_name = file_path

        try:
            # The file's own directory serves as an include path.
            include_dir = os.path.dirname(tf_name)
            args = ['-std=c11']
            if include_dir:
                args.append(f'-I{include_dir}')
            try:
                tu = self.index.parse(
                    tf_name,
                    args=args,
                    options=cindex.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD,
                )
            except Exception as e:
                raise RuntimeError(f"libclang failed to parse translation unit: {e}") from e
            self._walk_cursor(tu.cursor, declared, called, tf_name)
        finally:
            # Only delete the file if we created it.
            if temp_file:
                try:
                    os.unlink(tf_name)
                except OSError:
                    pass

        # Deduplicate
        seen_decl = set()
        unique_declared = []
        for e in declared:
            key = (e.get("name"), e.get("type"), e.get("dtype", None))
            if key not in seen_decl:
                unique_declared.append(e)
                seen_decl.add(key)
        unique_called = list(dict.fromkeys(called))
        return unique_declared, unique_called
class CppEntityExtractor(BaseASTEntityExtractor):
"""
Extract declared and called entities from C++ code using clang.cindex (libclang),
including classes, namespaces, and methods.
"""
def __init__(self):
self.index = cindex.Index.create()
self.reset()
def reset(self) -> None:
self.declared_entities = []
self.called_entities = []
self.scope_stack = []
def _qualified(self, name: str) -> str:
"""Return fully qualified name using current scope stack."""
if not name:
return ""
if not self.scope_stack:
return name
return "::".join(self.scope_stack + [name])
def _walk_cursor(self, cursor, source_file: str):
for c in cursor.get_children():
# --- Include directives ---
# Note: INCLUSION_DIRECTIVE nodes are at the root level and need special handling
if c.kind == cindex.CursorKind.INCLUSION_DIRECTIVE:
# Get the included file name
included_file = c.displayname
if included_file:
self.called_entities.append(included_file)
continue
kind = c.kind
# --- Namespace --- (process before location check)
if kind == cindex.CursorKind.NAMESPACE:
if c.spelling: # Only add non-empty namespace names
self.scope_stack.append(c.spelling)
self._walk_cursor(c, source_file)
if c.spelling:
self.scope_stack.pop()
continue
# Check location for other node types
loc = c.location
# Skip nodes from other files, but allow nodes without location info
if loc.file and os.path.abspath(loc.file.name) != os.path.abspath(source_file):
continue
# --- Class / Struct ---
if kind in (cindex.CursorKind.CLASS_DECL, cindex.CursorKind.STRUCT_DECL):
# Only process if it has a name
if c.spelling:
# Check if it's a definition (not a forward declaration)
is_def = c.is_definition() if hasattr(c, 'is_definition') else True
if is_def:
full_name = self._qualified(c.spelling)
self.declared_entities.append({"name": full_name, "type": "class"})
# Handle base classes (inheritance)
for base in c.get_children():
if base.kind == cindex.CursorKind.CXX_BASE_SPECIFIER:
if base.spelling:
self.called_entities.append(base.spelling)
self.scope_stack.append(c.spelling)
self._walk_cursor(c, source_file)
self.scope_stack.pop()
continue
# --- Methods ---
if kind in (cindex.CursorKind.CXX_METHOD, cindex.CursorKind.CONSTRUCTOR, cindex.CursorKind.DESTRUCTOR):
if c.spelling: # Only process if it has a name
full_name = self._qualified(c.spelling)
self.declared_entities.append({"name": full_name, "type": "method"})
for p in c.get_arguments():
if p.spelling: # Only add parameters with names
self.declared_entities.append({
"name": f"{full_name}.{p.spelling}",
"type": "variable",
"dtype": p.type.spelling
})
self._walk_cursor(c, source_file)
continue
# --- Free functions ---
if kind == cindex.CursorKind.FUNCTION_DECL:
if c.spelling: # Only process if it has a name
full_name = self._qualified(c.spelling)
self.declared_entities.append({"name": full_name, "type": "function"})
for p in c.get_arguments():
if p.spelling: # Only add parameters with names
self.declared_entities.append({
"name": f"{full_name}.{p.spelling}",
"type": "variable",
"dtype": p.type.spelling
})
self._walk_cursor(c, source_file)
continue
# --- Variables ---
if kind == cindex.CursorKind.VAR_DECL:
full_name = self._qualified(c.spelling)
self.declared_entities.append({
"name": full_name,
"type": "variable",
"dtype": c.type.spelling
})
# Look for TYPE_REF children which explicitly reference the type
# This is more reliable than c.type.spelling when includes aren't resolved
type_ref_found = False
for child in c.get_children():
if child.kind == cindex.CursorKind.TYPE_REF:
# TYPE_REF.spelling gives us the fully qualified type name
# It may have 'class ' or 'struct ' prefix, so strip it
if child.spelling:
type_name = child.spelling.replace('class ', '').replace('struct ', '').strip()
if type_name:
# TYPE_REF gives us the canonical name from the definition,
# which includes namespace qualifiers if present.
# We only add this canonical name and rely on alias resolution
# to match unqualified usage (e.g., 'Calculator' -> 'math::Calculator')
self.called_entities.append(type_name)
type_ref_found = True
break
# Fallback: use c.type.spelling if no TYPE_REF found
# Note: c.type.spelling may give us the name as written in source code,
# which could be unqualified even if it refers to a namespaced type
if not type_ref_found and c.type.spelling:
# Extract the base type name (remove const, &, *, etc.)
type_name = c.type.spelling.strip()
# Remove common qualifiers
type_name = type_name.replace('const', '').replace('&', '').replace('*', '').strip()
if type_name and not type_name in ['int', 'float', 'double', 'char', 'bool', 'void', 'long', 'short', 'unsigned', 'signed']:
# Only add if not already added via TYPE_REF
# c.type.spelling might give unqualified name even for namespaced types
# We'll add it and let alias resolution handle it
self.called_entities.append(type_name)
# --- Calls ---
if kind == cindex.CursorKind.CALL_EXPR:
callee = None
for child in c.get_children():
if child.kind in (cindex.CursorKind.DECL_REF_EXPR, cindex.CursorKind.MEMBER_REF_EXPR):
callee = child.spelling
break
if callee:
self.called_entities.append(callee)
else:
self.called_entities.append(c.displayname or c.spelling)
# Recurse
self._walk_cursor(c, source_file)
def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
    """
    Extract declared and called entities from C++ source using libclang.

    When ``file_path`` names an existing file, libclang parses that file
    directly (its directory is added as an include path) and ``code`` is not
    read. Otherwise ``code`` is written to a temporary ``.cpp`` file, which is
    parsed and then removed, even if parsing or walking fails.

    Args:
        code: C++ source text; only used when ``file_path`` is not an existing file.
        file_path: Optional path to the source file on disk (improves include resolution).

    Returns:
        Tuple of (declared_entities, called_entities), each de-duplicated while
        preserving first-seen order.

    Raises:
        RuntimeError: If libclang fails to parse the translation unit.
    """
    self.reset()

    # Prefer the real file so relative #include directives resolve next to it.
    temp_file = not (file_path and os.path.exists(file_path))
    if temp_file:
        # delete=False: the file is closed before libclang reopens it by name.
        # UTF-8 is pinned because the default text encoding is locale-dependent.
        tf = tempfile.NamedTemporaryFile(suffix=".cpp", mode="w", encoding="utf-8", delete=False)
        tf_name = tf.name
    else:
        tf = None
        tf_name = file_path

    try:
        if tf is not None:
            with tf:
                tf.write(code)

        # Add the file's directory as an include path (skipped for a bare file name).
        include_dir = os.path.dirname(tf_name)
        args = ['-std=c++17', '-xc++']
        if include_dir:
            args.append(f'-I{include_dir}')

        try:
            tu = self.index.parse(
                tf_name,
                args=args,
                options=cindex.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD
            )
        except Exception as parse_error:
            raise RuntimeError(f"libclang failed to parse C++ translation unit: {parse_error}") from parse_error

        self._walk_cursor(tu.cursor, tf_name)
    finally:
        # Only remove the file we created; never touch the caller's file.
        if temp_file:
            try:
                os.unlink(tf_name)
            except OSError as cleanup_error:
                logger.debug(f"Could not remove temporary file {tf_name}: {cleanup_error}")

    # Deduplicate declarations on (name, type, dtype), keeping the first occurrence.
    seen_decl = set()
    unique_declared = []
    for entity in self.declared_entities:
        key = (entity.get("name"), entity.get("type"), entity.get("dtype", None))
        if key not in seen_decl:
            unique_declared.append(entity)
            seen_decl.add(key)

    unique_called = list(dict.fromkeys(self.called_entities))
    return unique_declared, unique_called
class RustEntityExtractor(BaseASTEntityExtractor):
    """
    Extract declared and called entities from Rust code using tree-sitter.

    Declared entities:
      * modules, structs (plus named fields), enums (plus variants), traits;
      * functions and methods (plus their non-``self`` parameters);
      * type aliases, constants and statics.

    Called entities:
      * call expressions and macro invocations;
      * ``use`` paths;
      * traits named in ``impl Trait for Type`` blocks.

    API endpoints are detected from route attributes placed directly on a
    function, such as ``#[get("/users")]`` or
    ``#[route("/users", method="GET")]`` (the attribute style used by
    Actix-web and Rocket). Builder-style routing, as used by Axum and Warp,
    is not detected.
    """

    # HTTP method route macros for Rust web frameworks
    ROUTE_MACROS = {
        'get', 'post', 'put', 'patch', 'delete', 'head', 'options',  # Actix-web, Rocket
        'Get', 'Post', 'Put', 'Patch', 'Delete', 'Head', 'Options',  # Alternative casing
    }

    # Route-related macros and functions
    ROUTE_PATTERNS = {
        'route',  # Generic route macro
        'web::get', 'web::post', 'web::put', 'web::delete',  # Actix-web with web::
    }
    # NOTE(review): ROUTE_MACROS / ROUTE_PATTERNS are not read by this class (detection
    # uses the regexes below); they are kept as public attributes for external readers.

    # Compiled once, because attribute text is matched for every function in a file.
    # #[get("/path")], #[post("/path", data = "<var>")] ... (extra arguments after the path allowed)
    _METHOD_ATTR_RE = re.compile(
        r'#\[(get|post|put|patch|delete|head|options)\s*\(\s*"([^"]+)"(?:\s*,.*?)?\s*\)\]',
        re.IGNORECASE,
    )
    # #[route("/path", method="GET")]; the method defaults to GET when absent
    _ROUTE_ATTR_RE = re.compile(
        r'#\[route\s*\(\s*"([^"]+)"(?:.*?method\s*=\s*"([^"]+)")?\s*\)\]',
        re.IGNORECASE,
    )
    # Sibling node types allowed between a function and the attributes attached to it.
    _ATTRIBUTE_RUN_TYPES = frozenset({'attribute_item', 'line_comment', 'block_comment'})
    # Scope kinds whose directly nested functions are methods (modules are not).
    _METHOD_SCOPE_KINDS = frozenset({'impl', 'trait'})
    # Receiver patterns that are not recorded as parameters.
    _SELF_PARAMS = frozenset({'self', '&self', '&mut self', 'mut self'})

    def __init__(self):
        self.parser = Parser()
        self.parser.language = Language(ts_rust.language())
        self.reset()

    def reset(self) -> None:
        """Clear all per-extraction state so the instance can be reused."""
        self.declared_entities: List[Dict[str, Any]] = []
        self.called_entities: List[str] = []
        self.scope_stack: List[str] = []
        # Parallel to scope_stack: kind of each open scope ('module', 'trait', 'impl').
        self._scope_kinds: List[str] = []
        self.api_endpoints: List[Dict[str, Any]] = []  # Track API endpoint definitions

    def _qualified(self, name: str) -> str:
        """Return fully qualified name using current scope stack."""
        if not name:
            return ""
        if not self.scope_stack:
            return name
        return "::".join(self.scope_stack + [name])

    def _get_node_text(self, node, code_bytes: bytes) -> str:
        """Extract text content of a node."""
        return code_bytes[node.start_byte:node.end_byte].decode('utf8')

    def _push_scope(self, name: str, kind: str) -> None:
        """Enter a named scope of the given kind."""
        self.scope_stack.append(name)
        self._scope_kinds.append(kind)

    def _pop_scope(self) -> None:
        """Leave the innermost scope."""
        self.scope_stack.pop()
        self._scope_kinds.pop()

    def _in_method_scope(self) -> bool:
        """True when the innermost open scope is an ``impl`` or ``trait`` block."""
        return bool(self._scope_kinds) and self._scope_kinds[-1] in self._METHOD_SCOPE_KINDS

    def _impl_type_name(self, type_node, code_bytes: bytes) -> str:
        """
        Return the name used to scope an ``impl`` block's items.

        For ``impl<T> Foo<T>`` the type is a ``generic_type`` node. Its base
        ``type`` field is used, so that methods are qualified as ``Foo::new``
        and match the declared struct ``Foo``.
        """
        if type_node.type == 'generic_type':
            base = type_node.child_by_field_name('type')
            if base is not None:
                return self._get_node_text(base, code_bytes)
        return self._get_node_text(type_node, code_bytes)

    def _walk_body(self, node, code_bytes: bytes) -> None:
        """Walk every child of ``node``'s ``body`` field, if it has one."""
        body = node.child_by_field_name('body')
        if body:
            for child in body.children:
                self._walk_tree(child, code_bytes)

    @staticmethod
    def _endpoint_info(endpoint_path: str, http_method: str) -> Dict[str, Any]:
        """Build the endpoint descriptor returned by attribute detection."""
        return {
            "endpoint": endpoint_path,
            "methods": [http_method],
            "type": "api_endpoint_definition"
        }

    def _extract_api_endpoint_from_attributes(self, node, code_bytes: bytes) -> Optional[Dict[str, Any]]:
        """
        Extract API endpoint information from Rust function attributes.

        Handles patterns like:
        - #[get("/users")]                    # Actix-web, Rocket
        - #[post("/users")]                   # Actix-web, Rocket
        - #[route("/users", method="GET")]    # Generic route

        In the tree-sitter Rust AST, attributes are PREVIOUS SIBLINGS of the
        ``function_item`` node, not its children. The walk goes backwards
        over the contiguous run of attributes and comments, and returns the
        first route attribute found (the one closest to the function).

        Returns:
            ``{"endpoint", "methods", "type"}``, or None if no route attribute is found.
        """
        sibling = node.prev_sibling
        while sibling is not None and sibling.type in self._ATTRIBUTE_RUN_TYPES:
            if sibling.type == 'attribute_item':
                attr_text = self._get_node_text(sibling, code_bytes)

                method_match = self._METHOD_ATTR_RE.search(attr_text)
                if method_match:
                    return self._endpoint_info(method_match.group(2), method_match.group(1).upper())

                route_match = self._ROUTE_ATTR_RE.search(attr_text)
                if route_match:
                    http_method = route_match.group(2).upper() if route_match.group(2) else "GET"
                    return self._endpoint_info(route_match.group(1), http_method)
            sibling = sibling.prev_sibling
        return None

    def _record_parameters(self, func_node, qualified: str, code_bytes: bytes) -> None:
        """Declare each typed parameter of ``func_node`` as ``<qualified>.<param>`` (receivers skipped)."""
        params = func_node.child_by_field_name('parameters')
        if not params:
            return
        for child in params.children:
            if child.type != 'parameter':
                continue
            pattern = child.child_by_field_name('pattern')
            if not pattern:
                continue
            param_name = self._get_node_text(pattern, code_bytes)
            if param_name in self._SELF_PARAMS:
                continue
            type_node = child.child_by_field_name('type')
            param_type = self._get_node_text(type_node, code_bytes) if type_node else "unknown"
            self.declared_entities.append({
                "name": f"{qualified}.{param_name}",
                "type": "variable",
                "dtype": param_type
            })

    def _walk_tree(self, node, code_bytes: bytes):
        """
        Recursively walk the tree-sitter AST, recording declarations and calls.

        Container items (modules, structs, enums, traits, impls, functions,
        type aliases) handle their own children and return early. Every other
        node falls through to a generic walk over its children. Local ``let``
        bindings and field accesses are intentionally not recorded, to avoid
        clutter.
        """
        node_type = node.type

        # --- Module declarations: mod my_module { ... } ---
        if node_type == 'mod_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                mod_name = self._get_node_text(name_node, code_bytes)
                self.declared_entities.append({"name": self._qualified(mod_name), "type": "module"})
                self._push_scope(mod_name, 'module')
                self._walk_body(node, code_bytes)
                self._pop_scope()
                return

        # --- Struct declarations (named fields recorded as <Struct>.<field>) ---
        elif node_type == 'struct_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                qualified = self._qualified(self._get_node_text(name_node, code_bytes))
                self.declared_entities.append({"name": qualified, "type": "struct"})

                type_params = node.child_by_field_name('type_parameters')
                if type_params:
                    self._walk_tree(type_params, code_bytes)

                body = node.child_by_field_name('body')
                if body:
                    for child in body.children:
                        if child.type != 'field_declaration':
                            continue
                        field_name_node = child.child_by_field_name('name')
                        if not field_name_node:
                            continue
                        field_type_node = child.child_by_field_name('type')
                        self.declared_entities.append({
                            "name": f"{qualified}.{self._get_node_text(field_name_node, code_bytes)}",
                            "type": "field",
                            "dtype": self._get_node_text(field_type_node, code_bytes) if field_type_node else "unknown"
                        })
                return

        # --- Enum declarations (variants recorded as <Enum>::<Variant>) ---
        elif node_type == 'enum_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                qualified = self._qualified(self._get_node_text(name_node, code_bytes))
                self.declared_entities.append({"name": qualified, "type": "enum"})

                body = node.child_by_field_name('body')
                if body:
                    for child in body.children:
                        if child.type != 'enum_variant':
                            continue
                        variant_name_node = child.child_by_field_name('name')
                        if variant_name_node:
                            self.declared_entities.append({
                                "name": f"{qualified}::{self._get_node_text(variant_name_node, code_bytes)}",
                                "type": "enum_variant"
                            })
                return

        # --- Trait declarations ---
        elif node_type == 'trait_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                trait_name = self._get_node_text(name_node, code_bytes)
                self.declared_entities.append({"name": self._qualified(trait_name), "type": "trait"})
                self._push_scope(trait_name, 'trait')
                self._walk_body(node, code_bytes)
                self._pop_scope()
                return

        # --- Implementation blocks: impl MyStruct { ... } / impl Trait for MyStruct { ... } ---
        elif node_type == 'impl_item':
            type_node = node.child_by_field_name('type')
            trait_node = node.child_by_field_name('trait')
            impl_name = self._impl_type_name(type_node, code_bytes) if type_node else None

            if trait_node:
                self.called_entities.append(self._get_node_text(trait_node, code_bytes))

            if impl_name:
                self._push_scope(impl_name, 'impl')
            self._walk_body(node, code_bytes)
            if impl_name:
                self._pop_scope()
            return

        # --- Function declarations ---
        elif node_type == 'function_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                qualified = self._qualified(self._get_node_text(name_node, code_bytes))

                # Route attributes such as #[get("/users")] mark an API endpoint handler.
                api_info = self._extract_api_endpoint_from_attributes(node, code_bytes)
                if api_info:
                    self.declared_entities.append({
                        "name": qualified,
                        "type": "api_endpoint",
                        "endpoint": api_info.get("endpoint"),
                        "methods": api_info.get("methods")
                    })
                    self.api_endpoints.append({**api_info, "function": qualified})
                else:
                    # Only functions directly inside impl/trait blocks are methods;
                    # functions inside plain modules stay free functions.
                    entity_type = "method" if self._in_method_scope() else "function"
                    self.declared_entities.append({"name": qualified, "type": entity_type})

                self._record_parameters(node, qualified, code_bytes)

                # Walk the function body to find calls
                body = node.child_by_field_name('body')
                if body:
                    self._walk_tree(body, code_bytes)
                return

        # --- Type alias ---
        elif node_type == 'type_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                qualified = self._qualified(self._get_node_text(name_node, code_bytes))
                self.declared_entities.append({"name": qualified, "type": "type_alias"})
                return

        # --- Constant / static declarations (children still walked below) ---
        elif node_type in ('const_item', 'static_item'):
            name_node = node.child_by_field_name('name')
            type_node = node.child_by_field_name('type')
            if name_node:
                self.declared_entities.append({
                    "name": self._qualified(self._get_node_text(name_node, code_bytes)),
                    "type": "constant" if node_type == 'const_item' else "static",
                    "dtype": self._get_node_text(type_node, code_bytes) if type_node else "unknown"
                })

        # --- Use declarations (imports): record the path, not the whole statement ---
        elif node_type == 'use_declaration':
            argument = node.child_by_field_name('argument')
            target = argument if argument is not None else node
            self.called_entities.append(self._get_node_text(target, code_bytes))

        # --- Call expressions: foo(), obj.method(), std::vec::Vec::new() ---
        elif node_type == 'call_expression':
            function = node.child_by_field_name('function')
            if function:
                self.called_entities.append(self._get_node_text(function, code_bytes))

        # --- Macro invocations: recorded as name! ---
        elif node_type == 'macro_invocation':
            macro_node = node.child_by_field_name('macro')
            if macro_node:
                self.called_entities.append(f"{self._get_node_text(macro_node, code_bytes)}!")

        # Recursively walk all children
        for child in node.children:
            self._walk_tree(child, code_bytes)

    def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
        """
        Extract entities from Rust code using tree-sitter.

        Args:
            code: Rust source text.
            file_path: Unused; accepted for interface compatibility.

        Returns:
            Tuple of (declared_entities, called_entities), each de-duplicated
            while preserving first-seen order.
        """
        self.reset()
        code_bytes = code.encode('utf8')
        tree = self.parser.parse(code_bytes)

        self._walk_tree(tree.root_node, code_bytes)

        # Deduplicate declarations on (name, type, dtype), keeping the first occurrence.
        seen_decl = set()
        unique_declared = []
        for entity in self.declared_entities:
            key = (entity.get("name"), entity.get("type"), entity.get("dtype", None))
            if key not in seen_decl:
                unique_declared.append(entity)
                seen_decl.add(key)

        unique_called = list(dict.fromkeys(self.called_entities))
        return unique_declared, unique_called
class PythonASTEntityExtractor(ast.NodeVisitor, BaseASTEntityExtractor):
    """
    AST-based entity extractor for Python code.

    Walks a module with ``ast.NodeVisitor`` and records:
      * declared entities: classes, functions/methods, their parameters,
        assigned variables and ``self.<attr>`` attributes. Each is a dict with
        ``name`` and ``type`` keys and, where applicable, ``dtype``.
      * called entities: called names, base classes, and imported
        modules/names.

    API endpoint definitions are detected from call-style attribute
    decorators:
      * Flask ``@app.route(...)``;
      * FastAPI/Flask method shortcuts such as ``@app.get(...)`` and
        ``@router.post(...)``.

    ``api_view`` (Django REST Framework) is recognised, but it carries no
    URL path, so it never produces an endpoint on its own.
    """

    # Common HTTP decorators/patterns for Python web frameworks
    API_DECORATORS = {
        'route',  # Flask @app.route
        'get', 'post', 'put', 'patch', 'delete', 'head', 'options',  # FastAPI/Flask methods
        'api_view',  # DRF @api_view
    }
    # Decorator names that are themselves the HTTP method (@app.get, @router.post, ...).
    HTTP_METHOD_DECORATORS = frozenset({'get', 'post', 'put', 'patch', 'delete', 'head', 'options'})
    # Keyword arguments that may carry the route path: FastAPI path=..., Flask rule=...
    _PATH_KEYWORDS = ('path', 'rule')

    def __init__(self):
        self.reset()

    def reset(self) -> None:
        """Clear previous extraction state including context"""
        self.declared_entities: List[Dict[str, Any]] = []
        self.called_entities: List[str] = []
        self.current_class: Optional[str] = None
        self.current_function: Optional[str] = None
        self.api_endpoints: List[Dict[str, Any]] = []  # Track API endpoint definitions

    def _get_type_annotation(self, node: ast.AST) -> str:
        """
        Render an annotation (or dotted name) node as a readable type string.

        Handles the following forms, and returns "unknown" for anything else:
          * names and dotted attributes;
          * subscripted generics such as ``List[str]``;
          * string forward references (``"Foo"`` -> ``Foo``);
          * ``None``;
          * PEP 604 unions (``int | None``).
        """
        if isinstance(node, ast.Name):
            return node.id
        if isinstance(node, ast.Constant):
            # A string annotation is a forward reference to the named type.
            if isinstance(node.value, str):
                return node.value
            if node.value is None:
                return "None"
            return type(node.value).__name__
        if isinstance(node, ast.Attribute):
            return f"{self._get_type_annotation(node.value)}.{node.attr}"
        if isinstance(node, ast.Subscript):
            # Handle generic types like List[str], Dict[str, int]
            base = self._get_type_annotation(node.value)
            if isinstance(node.slice, ast.Tuple):
                args = [self._get_type_annotation(elt) for elt in node.slice.elts]
                return f"{base}[{', '.join(args)}]"
            return f"{base}[{self._get_type_annotation(node.slice)}]"
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
            # PEP 604 union: X | Y
            return f"{self._get_type_annotation(node.left)} | {self._get_type_annotation(node.right)}"
        return "unknown"

    def _infer_type_from_value(self, node: ast.AST) -> str:
        """
        Infer a type name from an assigned value expression.

        Recognised forms:
          * literals, f-strings, containers and comprehensions;
          * calls to a bare name (treated as a constructor, e.g. ``Foo()`` -> ``Foo``).

        Anything else returns "unknown".
        """
        if isinstance(node, ast.Constant):
            return type(node.value).__name__
        if isinstance(node, ast.JoinedStr):  # f-string
            return "str"
        if isinstance(node, (ast.List, ast.ListComp)):
            return "list"
        if isinstance(node, (ast.Dict, ast.DictComp)):
            return "dict"
        if isinstance(node, (ast.Set, ast.SetComp)):
            return "set"
        if isinstance(node, ast.Tuple):
            return "tuple"
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            return node.func.id  # Constructor call
        # Attribute calls and references to other variables are not resolved.
        return "unknown"

    def visit_ClassDef(self, node: ast.ClassDef):
        """Declare the class, record its base classes as called entities, then visit its body."""
        old_class = self.current_class
        self.current_class = node.name

        self.declared_entities.append({
            "name": node.name,
            "type": "class"
        })

        for base in node.bases:
            if isinstance(base, ast.Name):
                self.called_entities.append(base.id)
            elif isinstance(base, ast.Attribute):
                self.called_entities.append(self._get_type_annotation(base))

        self.generic_visit(node)
        self.current_class = old_class

    @staticmethod
    def _iter_parameters(arguments: ast.arguments) -> List[ast.arg]:
        """Return every parameter in signature order: positional-only, regular, *args, keyword-only, **kwargs."""
        params = [*arguments.posonlyargs, *arguments.args]
        if arguments.vararg is not None:
            params.append(arguments.vararg)
        params.extend(arguments.kwonlyargs)
        if arguments.kwarg is not None:
            params.append(arguments.kwarg)
        return params

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """
        Declare a function or method, and its parameters, then visit its body.

        Methods are named ``Class.method``; their parameters are named
        ``Class.method.param``. Parameters of plain functions use the bare
        parameter name.

        A function with a recognised route decorator is declared with type
        "api_endpoint", carries ``endpoint``/``methods`` keys, and is also
        appended to ``self.api_endpoints``.
        """
        old_function = self.current_function

        if self.current_class:
            full_name = f"{self.current_class}.{node.name}"
            entity_type = "method"
        else:
            full_name = node.name
            entity_type = "function"

        self.current_function = full_name

        api_info = self._extract_api_endpoint_from_decorators(node.decorator_list, full_name)
        if api_info:
            self.declared_entities.append({
                "name": full_name,
                "type": "api_endpoint",
                "endpoint": api_info.get("endpoint"),
                "methods": api_info.get("methods")
            })
            self.api_endpoints.append(api_info)
        else:
            self.declared_entities.append({
                "name": full_name,
                "type": entity_type
            })

        for arg in self._iter_parameters(node.args):
            if arg.arg == 'self' and self.current_class:
                continue  # Skip self parameter
            dtype = self._get_type_annotation(arg.annotation) if arg.annotation else "unknown"
            param_name = f"{full_name}.{arg.arg}" if entity_type == "method" else arg.arg
            self.declared_entities.append({
                "name": param_name,
                "type": "variable",
                "dtype": dtype
            })

        self.generic_visit(node)
        self.current_function = old_function

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
        """Visit async function/method definitions exactly like synchronous ones."""
        self.visit_FunctionDef(node)

    def visit_Assign(self, node: ast.Assign):
        """
        Declare variables bound by a plain assignment, with a type inferred from the value.

        Names assigned inside a method body are not declared. Names assigned
        elsewhere, including plain function bodies, are declared under their
        bare name. ``self.attr`` targets inside a class are declared as
        ``Class.attr``.
        """
        dtype = self._infer_type_from_value(node.value)

        for target in node.targets:
            if isinstance(target, ast.Name):
                var_name = target.id
                if self.current_class and self.current_function and self.current_function.startswith(self.current_class):
                    # Local variable in method
                    pass  # Could add local variables if needed
                else:
                    self.declared_entities.append({
                        "name": var_name,
                        "type": "variable",
                        "dtype": dtype
                    })
            elif isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name):
                # Attribute assignment like self.name = value
                if target.value.id == 'self' and self.current_class:
                    self.declared_entities.append({
                        "name": f"{self.current_class}.{target.attr}",
                        "type": "variable",
                        "dtype": dtype
                    })

        # Continue visiting to catch function calls in the assignment
        self.generic_visit(node)

    def visit_AnnAssign(self, node: ast.AnnAssign):
        """Declare variables from annotated assignments (PEP 526), using the annotation as dtype."""
        if isinstance(node.target, ast.Name):
            self.declared_entities.append({
                "name": node.target.id,
                "type": "variable",
                "dtype": self._get_type_annotation(node.annotation)
            })
        elif isinstance(node.target, ast.Attribute) and isinstance(node.target.value, ast.Name):
            if node.target.value.id == 'self' and self.current_class:
                self.declared_entities.append({
                    "name": f"{self.current_class}.{node.target.attr}",
                    "type": "variable",
                    "dtype": self._get_type_annotation(node.annotation)
                })

        # Only visit children when a value is assigned (bare declarations have no calls).
        if node.value:
            self.generic_visit(node)

    def visit_Import(self, node: ast.Import):
        """Record each imported module/package as a called entity."""
        for alias in node.names:
            self.called_entities.append(alias.name)
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom):
        """Record the source module and each ``module.name`` imported from it (wildcards skipped)."""
        if node.module:
            self.called_entities.append(node.module)
            for alias in node.names:
                if alias.name != '*':
                    self.called_entities.append(f"{node.module}.{alias.name}")
        else:
            # Relative imports without module (from . import x)
            for alias in node.names:
                if alias.name != '*':
                    self.called_entities.append(alias.name)
        self.generic_visit(node)

    def visit_Call(self, node: ast.Call):
        """
        Record call targets.

        Each call is recorded as one of:
          * the bare callee name, for ``f()``;
          * ``Type.method``, for ``obj.method()`` when ``obj`` has a known declared type;
          * the bare method name, for ``obj.method()`` when the type is unknown;
          * the full dotted path, for ``a.b.c()``.
        """
        if isinstance(node.func, ast.Name):
            self.called_entities.append(node.func.id)
        elif isinstance(node.func, ast.Attribute):
            if isinstance(node.func.value, ast.Name):
                method_name = node.func.attr
                obj_class = self._find_variable_type(node.func.value.id)
                if obj_class and obj_class != "unknown":
                    self.called_entities.append(f"{obj_class}.{method_name}")
                else:
                    self.called_entities.append(method_name)
            elif isinstance(node.func.value, ast.Attribute):
                # Nested attribute access like module.Class.method()
                self.called_entities.append(self._get_type_annotation(node.func))

        self.generic_visit(node)

    def _find_variable_type(self, var_name: str) -> str:
        """Return the dtype of the first declared variable named ``var_name``, or "unknown"."""
        for entity in self.declared_entities:
            if entity["name"] == var_name and entity["type"] == "variable":
                return entity.get("dtype", "unknown")
        return "unknown"

    @staticmethod
    def _string_elements(node: ast.AST) -> List[str]:
        """Return the string constants of a list/tuple/set literal ([] for any other node)."""
        if isinstance(node, (ast.List, ast.Tuple, ast.Set)):
            return [
                elt.value for elt in node.elts
                if isinstance(elt, ast.Constant) and isinstance(elt.value, str)
            ]
        return []

    def _decorator_route_path(self, decorator: ast.Call) -> Optional[str]:
        """Return the route path from the first positional string argument, or from path=/rule=."""
        if decorator.args:
            first = decorator.args[0]
            if isinstance(first, ast.Constant) and isinstance(first.value, str):
                return first.value
        for keyword in decorator.keywords:
            if (keyword.arg in self._PATH_KEYWORDS
                    and isinstance(keyword.value, ast.Constant)
                    and isinstance(keyword.value.value, str)):
                return keyword.value.value
        return None

    def _extract_api_endpoint_from_decorators(self, decorators: List[ast.expr], function_name: str) -> Optional[Dict[str, Any]]:
        """
        Extract API endpoint information from function decorators.

        Handles patterns like:
        - @app.route("/api/users", methods=["GET", "POST"])  # Flask (list/tuple/set; default GET)
        - @app.get("/api/users")                              # FastAPI
        - @router.post(path="/api/users")                     # FastAPI with router, keyword path

        HTTP method names are upper-cased. The first decorator that yields a
        non-empty path wins.

        Returns:
            ``{"function", "endpoint", "methods", "type"}``, or None if no decorator matches.
        """
        for decorator in decorators:
            # Only call-style attribute decorators qualify (app.route, app.get, router.post, ...)
            if not isinstance(decorator, ast.Call) or not isinstance(decorator.func, ast.Attribute):
                continue
            method_name = decorator.func.attr.lower()
            if method_name not in self.API_DECORATORS:
                continue

            endpoint = self._decorator_route_path(decorator)
            http_methods: List[str] = []

            if method_name in self.HTTP_METHOD_DECORATORS:
                http_methods = [method_name.upper()]
            elif method_name == 'route':
                for keyword in decorator.keywords:
                    if keyword.arg == 'methods':
                        http_methods = [m.upper() for m in self._string_elements(keyword.value)]
                if not http_methods:
                    http_methods = ['GET']  # Flask default
            elif method_name == 'api_view':
                if decorator.args:
                    http_methods = [m.upper() for m in self._string_elements(decorator.args[0])]

            if endpoint:
                return {
                    "function": function_name,
                    "endpoint": endpoint,
                    "methods": http_methods,
                    "type": "api_endpoint_definition"
                }

        return None

    def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
        """
        Extract entities from Python code using AST parsing.

        Args:
            code: Python source code as string
            file_path: Optional path to the source file (unused; accepted for interface compatibility)

        Returns:
            Tuple of (declared_entities, called_entities), de-duplicated in
            first-seen order. Returns ([], []) if the code cannot be parsed or
            visited; the error is logged.
        """
        self.reset()

        try:
            tree = ast.parse(code)
            self.visit(tree)

            # Deduplicate declarations on (name, type, dtype), keeping the first occurrence.
            seen_declared = set()
            unique_declared = []
            for entity in self.declared_entities:
                key = (entity["name"], entity["type"], entity.get("dtype"))
                if key not in seen_declared:
                    unique_declared.append(entity)
                    seen_declared.add(key)

            unique_called = list(dict.fromkeys(self.called_entities))
            return unique_declared, unique_called

        except SyntaxError as e:
            logger.error(f"Syntax error in Python code: {e}")
            return [], []
        except Exception as e:
            logger.error(f"Error parsing Python code: {e}", exc_info=True)
            return [], []
class HybridEntityExtractor:
    """
    Hybrid entity extractor that dispatches to an AST extractor by file extension.

    Supported extensions are mapped to language-specific extractors. For any
    other extension an exception is raised, so that the caller can fall back
    to LLM-based extraction.
    """

    def __init__(self):
        # One instance per extractor type, shared across that type's extensions.
        # Each call resets the extractor before use, so no state is carried over.
        cpp_extractor = CppEntityExtractor()
        js_extractor = JavaScriptEntityExtractor()
        self.extractors = {
            'py': PythonASTEntityExtractor(),
            'c': CEntityExtractor(),
            'h': cpp_extractor,  # C/C++ headers
            'cpp': cpp_extractor,
            'cc': cpp_extractor,
            'cxx': cpp_extractor,
            'hpp': cpp_extractor,
            'hxx': cpp_extractor,
            'hh': cpp_extractor,
            'java': JavaEntityExtractor(),
            'js': js_extractor,
            'jsx': js_extractor,
            'ts': js_extractor,  # TypeScript uses similar AST
            'tsx': js_extractor,  # TSX similar to JSX
            'rs': RustEntityExtractor(),
            'html': HTMLEntityExtractor()
        }

    def _get_language_from_filename(self, file_name: str) -> str:
        """
        Return the lower-cased extension of ``file_name`` without the dot.

        Returns '' when the final path component has no extension. Dots in
        directory names (e.g. ``pkg.v2/Makefile``) are not mistaken for one.
        """
        return os.path.splitext(file_name)[1][1:].lower()

    @staticmethod
    def _accepts_file_path(extractor) -> bool:
        """True if ``extractor.extract_entities`` accepts a ``file_path`` keyword (or **kwargs)."""
        import inspect

        try:
            params = inspect.signature(extractor.extract_entities).parameters
        except (TypeError, ValueError):
            return False
        return 'file_path' in params or any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
        )

    def extract_entities(self, code: str, file_name: str):
        """
        Extract (declared_entities, called_entities) from ``code`` with the extractor for ``file_name``.

        Each declared entity with a non-empty name gets an ``aliases`` list
        derived from the entity name and ``file_name``.

        Returns:
            The extractor's (declared_entities, called_entities), or ([], [])
            if the extractor raises (the error is logged).

        Raises:
            Exception: If no AST extractor handles the file's extension.
        """
        lang = self._get_language_from_filename(file_name)
        extractor = self.extractors.get(lang)
        if extractor is None:
            raise Exception(f"Using LLM extraction for unsupported language: {file_name}")

        # Reset the shared extractor instance so no state is carried over (best effort).
        try:
            extractor.reset()
        except Exception as reset_error:
            logger.debug(f"reset() failed for {type(extractor).__name__}, continuing: {reset_error}")

        logger.info(f"Using AST extraction for {lang.upper()} file: {file_name}")

        try:
            # Decide up front instead of retrying on TypeError: a TypeError raised
            # inside the extractor must not trigger a second, different extraction.
            if self._accepts_file_path(extractor):
                declared_entities, called_entities = extractor.extract_entities(code, file_path=file_name)
            else:
                declared_entities, called_entities = extractor.extract_entities(code)

            # Add aliases to each declared entity based on file path
            for entity in declared_entities:
                entity_name = entity.get('name', '')
                if entity_name:
                    aliases = generate_entity_aliases(entity_name, file_name)
                    entity['aliases'] = aliases
                    logger.debug(f"Generated aliases for entity '{entity_name}': {aliases}")

            return declared_entities, called_entities

        except Exception as e:
            logger.error(f"Error during AST extraction for file {file_name}: {e}", exc_info=True)
            return [], []