GaneshNaiknavare committed on
Commit
a18f446
·
verified ·
1 Parent(s): c600281

Create multi_layer_operation_predictor/operation_predictor.py

Browse files
multi_layer_operation_predictor/operation_predictor.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import lru_cache
2
+ import json
3
+ from enum import Enum
4
+ import os
5
+ import re
6
+ from fuzzywuzzy import fuzz
7
+ from sklearn.feature_extraction.text import TfidfVectorizer
8
+ from sklearn.neighbors import KNeighborsClassifier
9
+ import joblib
10
+ from nltk.stem import WordNetLemmatizer
11
+ import nltk
12
+
13
# Make sure the WordNet corpus needed by WordNetLemmatizer is available.
# Look for a local copy first so that importing this module does not contact
# the NLTK download server on every start-up; download only when it is missing.
try:
    nltk.data.find('corpora/wordnet')
except LookupError:
    nltk.download('wordnet')
14
+
15
class MatchMethod(Enum):
    """Tags identifying which matching layer a lookup result came from."""

    # Returned by find_closest_operation_fuzzy.
    FUZZY = "fuzzy"
    # Returned by find_closest_operation_substring.
    SUBSTRING = "substring"
    # Returned by predict_operation_name_ml.
    ML = "ml"
    # Not referenced by the code visible in this file; presumably reserved
    # for reporting a failed ML lookup -- TODO confirm with callers.
    ML_ERROR = "ml_error"
20
+
21
def get_absolute_path(relative_path):
    """Resolve ``relative_path`` against the directory that contains this module.

    Args:
        relative_path: A path expressed relative to this file's directory.
    Returns:
        ``relative_path`` joined onto the absolute directory of this file.
    """
    module_file = os.path.abspath(__file__)
    module_dir = os.path.dirname(module_file)
    return os.path.join(module_dir, relative_path)
24
+
25
def get_model_path(language):
    """Build the location of the persisted KNN model for ``language``.

    Args:
        language: Language identifier used as the file-name prefix.
    Returns:
        The path ``model/<language>_knn_model.h5`` resolved via
        ``get_absolute_path``.
    """
    # The ".h5" suffix is only a name: train_ml_model writes this file with
    # joblib.dump and load_ml_model reads it back with joblib.load.
    model_file = f'model/{language}_knn_model.h5'
    return get_absolute_path(model_file)
28
+
29
def load_functions(language):
    """Load function definitions from ``data/<language>/functions.json``.

    NOTE(review): a second ``load_functions`` is defined further down in this
    file (reading ``data/<language>_functions.json``). Because it is defined
    later, that definition replaces this one when the module is imported, so
    this version is currently never the one callers get.

    Args:
        language: Name of the sub-directory under ``data`` to read from.
    Returns:
        The parsed JSON content of the file.
    Raises:
        FileNotFoundError: If the definitions file does not exist.
    """
    lang_dir = get_absolute_path(f"data/{language}")
    file_path = os.path.join(lang_dir, 'functions.json')

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Function definitions for language '{language}' not found.")

    # JSON is UTF-8 text; be explicit instead of relying on the platform's
    # default encoding, which can mis-decode non-ASCII content.
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)
38
+
39
# Per-language keyword matchers, compiled once at import time. Each pattern
# matches whole words only (``\b`` boundaries), so identifiers that merely
# contain a keyword (e.g. "classify") are left untouched.
_KEYWORD_PATTERNS = {
    "python": re.compile(r"\b(def|class|async|await|for|while|if|else|try|except|finally)\b"),
    "javascript": re.compile(r"\b(function|class|async|await|const|let|var|if|else|for|while)\b"),
    "typescript": re.compile(r"\b(function|class|async|await|const|let|var|interface|type|enum|if|else|for|while)\b"),
    "php": re.compile(r"\b(function|class|public|private|protected|if|else|foreach|while|try|catch|finally)\b"),
    "java": re.compile(r"\b(class|public|private|protected|static|final|if|else|for|while|try|catch|finally)\b"),
}


def clean_function_name(function_name, language):
    """
    Strip language keywords from a function name and trim surrounding whitespace.

    Args:
        function_name: The original function name.
        language: The programming language of the function. Languages without
            an entry in the keyword table only get the whitespace trim.
    Returns:
        The cleaned function name. Whitespace *between* the remaining words is
        left exactly as it was; only the two ends are trimmed.
    """
    keyword_matcher = _KEYWORD_PATTERNS.get(language)
    if keyword_matcher is None:
        return function_name.strip()
    # Delete every keyword occurrence, then trim what is left.
    return keyword_matcher.sub("", function_name).strip()
61
+
62
@lru_cache(maxsize=1000)
def preprocess_text(text):
    """Normalise ``text`` by lowercasing it and lemmatizing each word.

    The text is split on whitespace, every token is passed through
    ``WordNetLemmatizer.lemmatize``, and the tokens are re-joined with single
    spaces. Results are memoised for up to 1000 distinct inputs.

    Args:
        text: The string to normalise (must be hashable for the cache).
    Returns:
        The lowercased, lemmatized text.
    """
    lemmatize = WordNetLemmatizer().lemmatize
    return " ".join(lemmatize(token) for token in text.lower().split())
69
+
70
def load_functions(language):
    """Loads functions from a JSON file based on the language.

    Two on-disk layouts are accepted and tried in this order:

    1. ``data/<language>_functions.json``
    2. ``data/<language>/functions.json`` (the layout expected by the earlier
       ``load_functions`` definition in this file, which this one replaces)

    Args:
        language: Language identifier used to build the file name.
    Returns:
        The parsed JSON content of the first file that exists.
    Raises:
        FileNotFoundError: If neither file exists.
    """
    candidate_paths = (
        get_absolute_path(f'data/{language}_functions.json'),
        os.path.join(get_absolute_path(f'data/{language}'), 'functions.json'),
    )
    for file_path in candidate_paths:
        if os.path.isfile(file_path):
            # JSON is UTF-8 text; do not depend on the platform's default
            # encoding when decoding it.
            with open(file_path, 'r', encoding='utf-8') as file:
                return json.load(file)
    raise FileNotFoundError(f"Function definitions for language '{language}' not found.")
76
+
77
def find_closest_operation_fuzzy(user_input, operations, threshold=70):
    """Finds the closest operation name using fuzzy matching.

    Args:
        user_input: The text to look up.
        operations: Iterable of operation names (iterating a dict yields its
            keys).
        threshold: Minimum ``fuzz.ratio`` score the best candidate must reach
            to count as a match. Defaults to 70.
    Returns:
        A ``(name, MatchMethod.FUZZY)`` tuple. ``name`` is the highest-scoring
        operation name (the first one on ties), or ``None`` when no candidate
        reaches ``threshold``.
    """
    best_match = None
    best_score = 0
    for operation_name in operations:
        score = fuzz.ratio(user_input, operation_name)
        # Strictly greater: on equal scores the earliest candidate wins.
        if score > best_score:
            best_score = score
            best_match = operation_name

    if best_score < threshold:
        best_match = None
    return best_match, MatchMethod.FUZZY
87
+
88
def find_closest_operation_substring(user_input, operations):
    """Finds the closest operation name using substring matching.

    Args:
        user_input: The text to look for inside each operation name.
        operations: Iterable of operation names (iterating a dict yields its
            keys).
    Returns:
        A ``(name, MatchMethod.SUBSTRING)`` tuple. ``name`` is the first
        operation name that contains ``user_input``, or ``None`` when there is
        no such name or ``user_input`` is empty.
    """
    # An empty string is a substring of every name; without this guard an
    # empty query would "match" whichever operation happens to come first.
    if not user_input:
        return None, MatchMethod.SUBSTRING

    first_hit = next(
        (operation_name for operation_name in operations if user_input in operation_name),
        None,
    )
    return first_hit, MatchMethod.SUBSTRING
94
+
95
def train_ml_model(operations, language):
    """Trains a KNN model on operation names and persists it.

    The operation names are vectorised with TF-IDF and used both as training
    samples and as their own labels. The fitted ``(model, vectorizer)`` pair
    is written to ``get_model_path(language)`` with ``joblib.dump``.

    Args:
        operations: Mapping whose keys are the operation names.
        language: Language identifier used to derive the model file path.
    Returns:
        A ``(knn_model, vectorizer)`` tuple of the fitted objects.
    Raises:
        ValueError: If ``operations`` contains no names to train on.
    """
    operation_names = list(operations.keys())
    if not operation_names:
        raise ValueError("Cannot train an operation model without any operation names.")

    vectorizer = TfidfVectorizer()
    X = vectorizer.fit_transform(operation_names)

    # Every label occurs exactly once in the training data. With uniform
    # voting, three neighbours therefore always carry three different labels
    # at probability 1/3 each, which can never reach the 0.5 cut-off applied
    # in predict_operation_name_ml. Distance weighting lets a clearly closer
    # neighbour dominate the vote. The neighbour count is capped at the sample
    # count because querying more neighbours than samples fails at prediction
    # time.
    knn_model = KNeighborsClassifier(
        n_neighbors=min(3, len(operation_names)),
        weights="distance",
    )
    knn_model.fit(X, operation_names)

    model_path = get_model_path(language)
    # The model directory may not exist yet on a fresh checkout.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    joblib.dump((knn_model, vectorizer), model_path)
    return knn_model, vectorizer
105
+
106
def load_ml_model(language):
    """Read the persisted KNN model and vectorizer for ``language``, if present.

    Args:
        language: Language identifier used to derive the model file path.
    Returns:
        Whatever ``joblib.load`` yields for the model file (train_ml_model
        stores a ``(knn_model, vectorizer)`` tuple there), or ``(None, None)``
        when the file does not exist.
    """
    model_path = get_model_path(language)
    if not os.path.exists(model_path):
        return None, None
    # NOTE: joblib files are pickle-based; only load model files you trust.
    return joblib.load(model_path)
113
+
114
def predict_operation_name_ml(user_input, knn_model, vectorizer, min_probability=0.5):
    """Predicts the operation name using the trained KNN model.

    Args:
        user_input: The text to classify.
        knn_model: Fitted classifier exposing ``predict_proba`` and ``predict``.
        vectorizer: Fitted vectorizer exposing ``transform``.
        min_probability: Lowest top-class probability accepted as a match.
            Defaults to 0.5.
    Returns:
        A ``(name, MatchMethod.ML)`` tuple. ``name`` is the predicted
        operation name, or an empty string when the model's best class
        probability is below ``min_probability``.
    """
    input_vector = vectorizer.transform([user_input])
    probabilities = knn_model.predict_proba(input_vector)[0]

    # NOTE(review): if the model was fitted with uniform neighbour weights and
    # one sample per label, the top probability is at most 1/n_neighbors and
    # may never reach this threshold -- verify against the training setup.
    if max(probabilities) < min_probability:
        # Not confident enough: report "no match" as an empty string.
        return "", MatchMethod.ML
    return knn_model.predict(input_vector)[0], MatchMethod.ML
123
+
124
def get_operation_definition(user_input, language):
    """Look up the definition of the operation that best matches ``user_input``.

    The input is cleaned with ``clean_function_name`` and normalised with
    ``preprocess_text``; the result is then matched against the operation
    names loaded for ``language`` using progressively looser strategies:
    exact name, substring, fuzzy ratio, and finally the KNN model (trained on
    the spot when no saved model exists).

    Args:
        user_input: Raw function/operation name supplied by the caller.
        language: Language identifier selecting the operations and the model.
    Returns:
        The value stored in the operations mapping for the matched name, or
        an empty string when nothing matches or the ML lookup fails.
    """
    operations = load_functions(language)
    cleaned_input = clean_function_name(user_input, language)
    preprocessed_input = preprocess_text(cleaned_input)

    # Nothing left to match on (e.g. the input consisted only of keywords or
    # whitespace). An empty string is contained in every name, so continuing
    # would return an arbitrary first operation.
    if not preprocessed_input:
        return ""

    # First, try an exact match, so a name is not lost to an earlier, longer
    # name that merely contains it.
    if preprocessed_input in operations:
        return operations[preprocessed_input]

    # Next, try substring matching
    closest_match, _ = find_closest_operation_substring(preprocessed_input, operations)
    if closest_match:
        return operations[closest_match]

    # Next, try fuzzy matching
    closest_match, _ = find_closest_operation_fuzzy(preprocessed_input, operations)
    if closest_match:
        return operations[closest_match]

    # Finally, try ML model
    knn_model, vectorizer = load_ml_model(language)
    if knn_model is None:
        knn_model, vectorizer = train_ml_model(operations, language)

    try:
        closest_match, _ = predict_operation_name_ml(preprocessed_input, knn_model, vectorizer)
    except Exception:
        # Deliberate best-effort: a failing model must not break the lookup,
        # but the failure is logged instead of being silently discarded.
        import logging
        logging.getLogger(__name__).exception(
            "ML operation lookup failed for language %r", language
        )
        return ""

    if closest_match:
        # A saved model may predict a label that is no longer defined.
        return operations.get(closest_match, "")
    return ""