SmartContractAudit / utils /propertyretriever.py
ajaxwin
fix: Update file paths and ensure model loading in PropertyRetriever
45bd962
"""
PropertyRetriever class for finding similar properties based on code embeddings.
It loads a dataset of properties, computes embeddings for their critical code sections,
and provides a method to retrieve the most similar property given a new code snippet.
"""
# ! Current data contains properties from contracts.json, making it more likely to find an exact match
import pandas as pd
import numpy as np
from sklearn.preprocessing import normalize
from data.data_loader import DEFAULT_CSV_PATH
from dotenv import dotenv_values
# Minimum cosine similarity (dot product of L2-normalised embeddings) that the
# best-matching property must reach; PropertyRetriever uses it as its match threshold.
SIMILARITY_THRESHOLD = 0.8 # Adjust as needed based on validation
# -------------------------------------------------------------------
# 1. Load the dataset and build the vector database (offline/once)
# -------------------------------------------------------------------
class PropertyRetriever:
    """Retrieve the dataset property whose critical code is most similar to a query.

    Each row of the property CSV is reduced to a "critical code" string
    (``FunctionBodies``, falling back to ``RelatedFunctions`` and then
    ``RuleContent``). That string is embedded with a sentence-transformers model
    and L2-normalised, so a dot product with a normalised query embedding is the
    cosine similarity.
    """

    # Columns tried, in order, when choosing the text to embed for a property.
    _CODE_COLUMNS = ("FunctionBodies", "RelatedFunctions", "RuleContent")

    def __init__(self, csv_path=None, similarity_threshold=None):
        """
        csv_path : path to the CSV file containing the columns:
            SpecHash, SpecIndex, Type, Name, StartLine, EndLine,
            MethodsInRule, RuleContent, RelatedFunctions,
            FunctionBodies, FilePath, Project, ContractCode,
            StateVarAssignment, RuleContentNL, Funcitonality
            Defaults to DEFAULT_CSV_PATH.
        similarity_threshold : minimum dot product (cosine similarity of the
            normalised embeddings) to consider a match.
            Defaults to SIMILARITY_THRESHOLD.

        The embedding model is not loaded here. Call ``load_model`` explicitly,
        or let ``get_similar_property`` load it on first use.
        """
        self.df = pd.read_csv(DEFAULT_CSV_PATH if csv_path is None else csv_path)
        self.threshold = (
            SIMILARITY_THRESHOLD if similarity_threshold is None else similarity_threshold
        )
        self.embedder = None  # SentenceTransformer instance, created by load_model()
        self.critical_codes = []  # critical_codes[i] is the text embedded for self.df row i
        self.embeddings = None  # normalised embeddings; row i <-> self.df row i

    @staticmethod
    def _critical_code(row) -> str:
        """Return the first non-blank, non-missing value among _CODE_COLUMNS, or ""."""
        for column in PropertyRetriever._CODE_COLUMNS:
            value = row.get(column, "")
            # Skip NaN cells so the literal text "nan" is never embedded.
            if not pd.isna(value) and str(value).strip():
                return str(value)
        return ""

    def load_model(self):
        """Load the embedding model (once) and embed every property in ``self.df``.

        Uses a lightweight, open-source embedding model. The model is only
        instantiated when ``self.embedder`` is still None. The property
        embeddings are rebuilt on every call, so they reflect the current
        ``self.df``.
        """
        if self.embedder is None:
            # Imported lazily so importing this module does not require
            # sentence-transformers until a model is actually needed.
            from sentence_transformers import SentenceTransformer

            # Pass None rather than "" when no token is configured, so the Hub
            # client can fall back to any locally configured credentials.
            hf_token = dotenv_values(".env").get("HF_TOKEN") or None
            self.embedder = SentenceTransformer(
                "all-MiniLM-L6-v2",
                use_auth_token=hf_token,
            )

        # Extract "critical code" from each property (FunctionBodies, with
        # fallbacks to RelatedFunctions and RuleContent).
        self.critical_codes = [self._critical_code(row) for _, row in self.df.iterrows()]

        if not self.critical_codes:
            # Nothing to embed; keep an empty 2-D matrix instead of calling
            # encode/normalize on an empty list.
            self.embeddings = np.zeros((0, 0))
            return

        embeddings = self.embedder.encode(self.critical_codes, show_progress_bar=True)  # type: ignore
        # Normalise so that a dot product equals cosine similarity.
        self.embeddings = normalize(embeddings, norm="l2")

    def get_similar_property(self, input_code: str) -> str:
        """
        Given a Solidity function code string, return the natural-language
        property (the ``RuleContentNL`` column) of the most similar dataset
        entry. Return an empty string when the input is empty or not a string,
        the dataset is empty, the best similarity is below ``self.threshold``,
        or the matched ``RuleContentNL`` is missing.

        Loads the model and embeddings on first use if ``load_model`` has not
        been called yet.
        """
        if not input_code or not isinstance(input_code, str):
            return ""
        if self.df.empty:
            return ""

        if self.embedder is None or self.embeddings is None:
            self.load_model()

        # Embed and normalise the subject code.
        query_emb = normalize(self.embedder.encode([input_code]), norm="l2")  # type: ignore
        # Cosine similarity against every database vector.
        similarities = np.dot(self.embeddings, query_emb.T).ravel()

        best_idx = int(np.argmax(similarities))
        if similarities[best_idx] < self.threshold:
            # No sufficiently similar property found.
            return ""

        matched = self.df.iloc[best_idx]["RuleContentNL"]
        return "" if pd.isna(matched) else str(matched)