# Spaces: Sleeping
# File size: 6,553 Bytes
# 0ad8a10 8de7150 0ad8a10
from functools import lru_cache
import json
from enum import Enum
import os
import re
from fuzzywuzzy import fuzz
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neighbors import KNeighborsClassifier
import joblib
from nltk.stem import WordNetLemmatizer
import nltk
from model.db.db_setup import get_db, ExtractedFile
# Runs at import time: asks NLTK to fetch the 'wordnet' resource. The
# WordNetLemmatizer used by preprocess_text is NLTK's WordNet-based lemmatizer.
nltk.download('wordnet')
class MatchMethod(Enum):
    """Tag identifying which matching strategy produced a lookup result."""

    # Returned alongside results of find_closest_operation_fuzzy.
    FUZZY = "fuzzy"
    # Returned alongside results of find_closest_operation_substring.
    SUBSTRING = "substring"
    # Returned alongside results of predict_operation_name_ml.
    ML = "ml"
    # NOTE(review): not referenced in the visible part of this file;
    # presumably marks a failed ML lookup - confirm against callers.
    ML_ERROR = "ml_error"
def get_absolute_path(relative_path):
    """Resolve *relative_path* against the directory containing this module.

    Args:
        relative_path: Path fragment to append to the module's directory.

    Returns:
        The module directory joined with ``relative_path``.
    """
    module_file = os.path.abspath(__file__)
    module_dir, _ = os.path.split(module_file)
    return os.path.join(module_dir, relative_path)
def get_model_path(language):
    """Build the on-disk location of the persisted KNN model for a language.

    Args:
        language: Language name embedded in the model file name.

    Returns:
        Path of ``model/<language>_knn_model.h5`` resolved through
        ``get_absolute_path``.
    """
    model_file = f'model/{language}_knn_model.h5'
    return get_absolute_path(model_file)
def load_functions(language: str) -> dict:
    """
    Collect the stored functions for one language from the database.

    Args:
        language: Programming language name; it is lower-cased before lookup.

    Returns:
        Dictionary of function names and their code, merged across every
        file returned for the language's extension. An empty dictionary is
        returned (after printing the error) when the language is not
        supported or when any exception is raised while loading.
    """
    # Supported language names and the file extension each one maps to.
    extension_by_language = {
        "python": "py",
        "javascript": "js",
        "typescript": "ts",
        "php": "php",
        "java": "java"
    }

    try:
        language = language.lower()
        extension = extension_by_language.get(language)
        if not extension:
            raise ValueError(f"Unsupported language: {language}")

        merged = {}
        # Records are consumed while the session is still open.
        with get_db() as db:
            records = ExtractedFile.get_by_extension(db, f".{extension}")
            for record in records:
                payload = record.file_data
                # Normalise both accepted shapes (a dict, or a list of
                # dicts) into a sequence of candidate dicts.
                if isinstance(payload, dict):
                    candidates = [payload]
                elif isinstance(payload, list):
                    candidates = payload
                else:
                    candidates = []
                for candidate in candidates:
                    if isinstance(candidate, dict):
                        merged.update(candidate)
        return merged
    except Exception as e:
        print(f"Error loading functions from database: {str(e)}")
        return {}
def clean_function_name(function_name, language):
    """
    Remove language-specific keywords from a function name.

    Args:
        function_name: The original function name (may include keywords
            such as ``def`` or ``function``).
        language: The programming language of the function. Matched
            case-insensitively; an unknown language removes no keywords.

    Returns:
        The name with every whole-word keyword of that language removed and
        leading/trailing whitespace stripped. Whitespace between the
        remaining words is left untouched.
    """
    keywords_dict = {
        "python": r"\b(def|class|async|await|for|while|if|else|try|except|finally)\b",
        "javascript": r"\b(function|class|async|await|const|let|var|if|else|for|while)\b",
        "typescript": r"\b(function|class|async|await|const|let|var|interface|type|enum|if|else|for|while)\b",
        "php": r"\b(function|class|public|private|protected|if|else|foreach|while|try|catch|finally)\b",
        "java": r"\b(class|public|private|protected|static|final|if|else|for|while|try|catch|finally)\b"
    }

    # load_functions lower-cases the language before its own lookup; do the
    # same here so that e.g. "Python" strips keywords instead of silently
    # skipping the cleanup. Non-string values are passed through unchanged.
    lookup_key = language.lower() if isinstance(language, str) else language

    pattern = keywords_dict.get(lookup_key)
    if pattern:
        # \b anchors ensure only whole words are removed ("define" stays).
        function_name = re.sub(pattern, "", function_name)
    return function_name.strip()
@lru_cache(maxsize=1000)
def preprocess_text(text):
    """Lower-case *text*, lemmatize each whitespace-separated word, and re-join.

    Results are memoised for up to 1000 distinct inputs.

    Args:
        text: The string to normalise.

    Returns:
        The lemmatized words joined by single spaces.
    """
    lemmatize = WordNetLemmatizer().lemmatize
    return " ".join(lemmatize(token) for token in text.lower().split())
def find_closest_operation_fuzzy(user_input, operations):
    """Pick the operation name most similar to *user_input* by ``fuzz.ratio``.

    Args:
        user_input: The text to match.
        operations: Iterable of operation names (iterating a dict yields
            its keys).

    Returns:
        A ``(name, MatchMethod.FUZZY)`` tuple. ``name`` is the highest
        scoring operation (the earliest one on ties) when its score is at
        least 70, otherwise ``None``.
    """
    scored = ((fuzz.ratio(user_input, name), name) for name in operations)
    # max() keeps the first of equally scored candidates; the default covers
    # an empty collection of operations.
    top_score, top_name = max(scored, key=lambda pair: pair[0], default=(0, None))
    if top_score < 70:
        top_name = None
    return top_name, MatchMethod.FUZZY
def find_closest_operation_substring(user_input, operations):
    """Find an operation whose name equals or contains *user_input*.

    Args:
        user_input: The text to look for.
        operations: Iterable of operation names (iterating a dict yields
            its keys).

    Returns:
        A ``(name, MatchMethod.SUBSTRING)`` tuple. ``name`` is an operation
        equal to ``user_input`` if there is one, otherwise the first
        operation containing it, otherwise ``None``. An empty
        ``user_input`` never matches.
    """
    # "" is a substring of every string, so without this guard an empty
    # query would "match" whichever operation happens to come first.
    if not user_input:
        return None, MatchMethod.SUBSTRING

    first_partial = None
    for operation_name in operations:
        # An exact hit always wins, even if an earlier name merely
        # contains the query (e.g. "add_all" listed before "add").
        if operation_name == user_input:
            return operation_name, MatchMethod.SUBSTRING
        if first_partial is None and user_input in operation_name:
            first_partial = operation_name
    return first_partial, MatchMethod.SUBSTRING
def train_ml_model(operations, language):
    """Train a KNN model on operation names and persist it to disk.

    The names are vectorised with TF-IDF and each name is its own class
    label. The fitted ``(model, vectorizer)`` pair is dumped with joblib to
    the path given by ``get_model_path(language)``.

    Args:
        operations: Mapping whose keys are the operation names to learn.
        language: Language name used to choose the model file.

    Returns:
        ``(knn_model, vectorizer)``, or ``(None, None)`` when there are no
        operation names to train on.
    """
    operation_names = list(operations)
    if not operation_names:
        # TfidfVectorizer cannot be fitted on an empty corpus; report
        # "no model" the same way load_ml_model does.
        return None, None

    vectorizer = TfidfVectorizer()
    X = vectorizer.fit_transform(operation_names)

    knn_model = KNeighborsClassifier(
        # Querying fails when n_neighbors exceeds the number of samples.
        n_neighbors=min(3, len(operation_names)),
        # Every label has exactly one sample, so with uniform voting the
        # top class probability is at most 1/n_neighbors and can never
        # reach the 0.5 confidence cut-off applied at prediction time.
        # Distance weighting lets a clearly nearest name dominate.
        weights="distance",
    )
    knn_model.fit(X, operation_names)

    model_path = get_model_path(language)
    model_dir = os.path.dirname(model_path)
    if model_dir:
        # joblib.dump does not create missing parent directories.
        os.makedirs(model_dir, exist_ok=True)
    joblib.dump((knn_model, vectorizer), model_path)
    return knn_model, vectorizer
def load_ml_model(language):
    """Load the persisted model for *language*, if a model file exists.

    Args:
        language: Language name used to locate the model file.

    Returns:
        Whatever ``joblib.load`` reads from the model path
        (``train_ml_model`` stores a ``(knn_model, vectorizer)`` tuple
        there), or ``(None, None)`` when the file does not exist.
    """
    model_path = get_model_path(language)
    if not os.path.exists(model_path):
        return None, None
    return joblib.load(model_path)
def predict_operation_name_ml(user_input, knn_model, vectorizer):
    """Predict an operation name for *user_input* with the KNN model.

    Args:
        user_input: The text to classify.
        knn_model: Fitted classifier exposing ``predict_proba`` and ``predict``.
        vectorizer: Fitted vectorizer exposing ``transform``.

    Returns:
        A ``(name, MatchMethod.ML)`` tuple. ``name`` is the predicted
        operation, or an empty string when the highest class probability is
        below 0.5.
    """
    features = vectorizer.transform([user_input])
    class_probabilities = knn_model.predict_proba(features)[0]
    top_probability = max(class_probabilities)
    # Low-confidence predictions are reported as "no match".
    if top_probability < 0.5:
        return "", MatchMethod.ML
    predicted_name = knn_model.predict(features)[0]
    return predicted_name, MatchMethod.ML
def get_operation_definition(user_input, language):
    """Return the stored code of the operation that best matches *user_input*.

    The input is cleaned of language keywords and lemmatized, then matched
    against the operation names loaded for ``language``. Strategies are
    tried in order and the first hit wins: substring match, fuzzy match,
    and finally the KNN model (loaded from disk, or trained if no model
    could be loaded).

    Args:
        user_input: Free-text name of the operation being looked up.
        language: Programming language whose operations are searched.

    Returns:
        The value stored for the matched operation, or ``""`` when nothing
        matches, when there is nothing to search, or when the ML step fails.
    """
    operations = load_functions(language)
    cleaned_input = clean_function_name(user_input, language)
    preprocessed_input = preprocess_text(cleaned_input)

    # Nothing to search, or nothing to search for. An empty query is a
    # substring of every name, so it must not reach the matchers.
    if not operations or not preprocessed_input:
        return ""

    # First, try substring containment.
    closest_match, _ = find_closest_operation_substring(preprocessed_input, operations)
    if closest_match:
        return operations[closest_match]

    # Next, try fuzzy matching.
    closest_match, _ = find_closest_operation_fuzzy(preprocessed_input, operations)
    if closest_match:
        return operations[closest_match]

    # Finally, try the ML model. Loading and training sit inside the try
    # as well: an unreadable model file or a failed training run should
    # degrade to "no match" rather than crash the lookup.
    try:
        knn_model, vectorizer = load_ml_model(language)
        if knn_model is None:
            knn_model, vectorizer = train_ml_model(operations, language)
        closest_match, _ = predict_operation_name_ml(preprocessed_input, knn_model, vectorizer)
    except Exception as e:
        # Best-effort step: report the failure the same way load_functions
        # does instead of swallowing it silently.
        print(f"Error during ML operation lookup: {str(e)}")
        return ""

    if closest_match:
        # A model persisted earlier may predict a name that is no longer
        # among the loaded operations.
        return operations.get(closest_match, "")
    return ""
|