File size: 6,553 Bytes
0ad8a10
 
 
 
 
 
 
 
 
 
 
8de7150
0ad8a10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
from functools import lru_cache
import json
from enum import Enum
import os
import re
from fuzzywuzzy import fuzz
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neighbors import KNeighborsClassifier
import joblib
from nltk.stem import WordNetLemmatizer
import nltk
from model.db.db_setup import get_db, ExtractedFile

# Ensure the WordNet corpus used by WordNetLemmatizer (see preprocess_text)
# is available. This executes at import time, every time the module is imported.
# NOTE(review): nltk.download may reach the network when the corpus is not
# already cached locally — consider moving this to an install/setup step.
nltk.download('wordnet')

class MatchMethod(Enum):
    """Tag naming the strategy that produced an operation-name match.

    The matcher helpers return one of these as the second element of their
    ``(name, method)`` result tuple.
    """

    FUZZY = "fuzzy"  # fuzzywuzzy ratio comparison
    SUBSTRING = "substring"  # containment test on the name
    ML = "ml"  # TF-IDF + KNN prediction
    # Not returned by any function visible in this module — TODO confirm
    # whether external callers rely on it.
    ML_ERROR = "ml_error"

def get_absolute_path(relative_path):
    """Resolve *relative_path* against the directory holding this module.

    Because the result is built with ``os.path.join``, an already-absolute
    *relative_path* is returned unchanged.
    """
    this_file = os.path.abspath(__file__)
    module_dir, _ = os.path.split(this_file)
    return os.path.join(module_dir, relative_path)

def get_model_path(language):
    """Return the absolute path of the persisted KNN model for *language*.

    The file is ``model/<language>_knn_model.h5`` beneath this module's
    directory. Note that despite the ``.h5`` suffix, the contents are written
    and read with ``joblib`` (see ``train_ml_model``/``load_ml_model``), not HDF5.
    """
    relative_location = f'model/{language}_knn_model.h5'
    return get_absolute_path(relative_location)

# Source-file extension (without the dot) for each supported language.
# Keys are lower-case; lookups lower-case the requested language first.
_LANGUAGE_EXTENSIONS = {
    "python": "py",
    "javascript": "js",
    "typescript": "ts",
    "php": "php",
    "java": "java",
}


def load_functions(language: str) -> dict:
    """
    Load functions from database based on language.

    Every stored file whose extension matches *language* is read, and its
    ``file_data`` is merged into one dictionary. ``file_data`` may be a dict,
    or a list whose dict elements are merged (other elements are ignored).
    When a name appears more than once, the last value seen wins.

    Args:
        language: The programming language to load functions for
            (case-insensitive, e.g. "Python").
    Returns:
        Dictionary of function names and their code. It is empty when the
        language is unsupported or when the database access fails; in both
        cases a message is printed.
    """
    language_ext = (
        _LANGUAGE_EXTENSIONS.get(language.lower())
        if isinstance(language, str)
        else None
    )
    if not language_ext:
        # Report this separately rather than letting it surface as a
        # "database" error through the handler below.
        print(f"Unsupported language: {language}")
        return {}

    try:
        with get_db() as db:
            files = ExtractedFile.get_by_extension(db, f".{language_ext}")

            all_functions = {}
            # Read file_data while the session is still open.
            for file in files:
                data = file.file_data
                if isinstance(data, dict):
                    all_functions.update(data)
                elif isinstance(data, list):
                    for entry in data:
                        if isinstance(entry, dict):
                            all_functions.update(entry)

            return all_functions

    except Exception as e:
        # Best-effort: callers treat an empty result as "nothing to match".
        print(f"Error loading functions from database: {e}")
        return {}

# Language keywords stripped from user-supplied function names, compiled once
# at import time instead of on every call. Keys are lower-case language names.
_KEYWORD_PATTERNS = {
    "python": re.compile(r"\b(def|class|async|await|for|while|if|else|try|except|finally)\b"),
    "javascript": re.compile(r"\b(function|class|async|await|const|let|var|if|else|for|while)\b"),
    "typescript": re.compile(r"\b(function|class|async|await|const|let|var|interface|type|enum|if|else|for|while)\b"),
    "php": re.compile(r"\b(function|class|public|private|protected|if|else|foreach|while|try|catch|finally)\b"),
    "java": re.compile(r"\b(class|public|private|protected|static|final|if|else|for|while|try|catch|finally)\b"),
}


def clean_function_name(function_name, language):
    """
    Removes language-specific keywords and strips leading/trailing whitespace using regex.

    Keywords are matched as whole words (``\\b`` boundaries), so identifiers
    such as ``for_each`` are left intact. Whitespace between the remaining
    words is not collapsed; only the ends are stripped.

    Args:
        function_name: The original function name.
        language: The programming language of the function. Matched
            case-insensitively, consistent with ``load_functions``; an
            unknown language leaves the name unchanged apart from stripping.
    Returns:
        The cleaned function name.
    """
    key = language.lower() if isinstance(language, str) else language
    pattern = _KEYWORD_PATTERNS.get(key)
    if pattern is not None:
        function_name = pattern.sub("", function_name)
    return function_name.strip()

@lru_cache(maxsize=1000)
def preprocess_text(text):
    """Normalise *text* for matching.

    Lower-cases the text, splits it on whitespace and lemmatizes each token
    with ``WordNetLemmatizer.lemmatize`` (called without a ``pos`` argument).
    Tokens are re-joined with single spaces. Results are memoised for up to
    1000 distinct inputs.
    """
    lemmatize = WordNetLemmatizer().lemmatize
    tokens = text.lower().split()
    return " ".join(map(lemmatize, tokens))

def find_closest_operation_fuzzy(user_input, operations, threshold=70):
    """Finds the closest operation name using fuzzy matching.

    Names are compared case-insensitively. The caller passes a query that
    ``preprocess_text`` has already lower-cased, while stored names keep their
    original case; a case-sensitive ratio would penalise every mixed-case name.

    Args:
        user_input: The query string.
        operations: Iterable of candidate names (iterating a dict yields its keys).
        threshold: Minimum ``fuzz.ratio`` score (0-100) accepted as a match.
            The default of 70 is the previous hard-coded cut-off.
    Returns:
        ``(name, MatchMethod.FUZZY)``. *name* is the original-case candidate
        with the highest score, or None when that score is below *threshold*.
        Among equal scores, the first candidate in iteration order is kept.
    """
    query = user_input.lower()
    best_match = None
    best_score = 0
    for operation_name in operations:
        score = fuzz.ratio(query, operation_name.lower())
        if score > best_score:
            best_score = score
            best_match = operation_name
    return (best_match if best_score >= threshold else None), MatchMethod.FUZZY

def find_closest_operation_substring(user_input, operations):
    """Finds the closest operation name using substring matching.

    Returns the first name, in iteration order, that contains *user_input*,
    compared case-insensitively. The caller's query is lower-cased by
    ``preprocess_text`` while stored names keep their case.

    An empty query never matches. ``""`` is a substring of every string, so
    without this guard it would return an arbitrary first operation.

    Returns:
        ``(name, MatchMethod.SUBSTRING)``, where *name* is None if nothing matched.
    """
    query = user_input.lower()
    if query:
        for operation_name in operations:
            if query in operation_name.lower():
                return operation_name, MatchMethod.SUBSTRING
    return None, MatchMethod.SUBSTRING

def train_ml_model(operations, language):
    """Trains a KNN model on operation names and saves it for *language*.

    Each operation name is its own class with exactly one training sample.
    The fitted ``(knn_model, vectorizer)`` pair is written with ``joblib.dump``
    to ``get_model_path(language)``, creating the parent directory if needed.

    Args:
        operations: Mapping whose keys are the operation names to learn.
        language: Language name used to build the model file path.
    Returns:
        The fitted ``(knn_model, vectorizer)`` pair.
    Raises:
        ValueError: raised by ``TfidfVectorizer`` when *operations* is empty
            or no name produces a usable token.
    """
    operation_names = list(operations.keys())
    vectorizer = TfidfVectorizer()
    X = vectorizer.fit_transform(operation_names)
    # With one sample per class and uniform weights, each of the 3 neighbours
    # contributes exactly 1/3, so predict_proba could never reach the 0.5
    # cut-off in predict_operation_name_ml and the ML step never matched.
    # Distance weighting lets a clearly closer neighbour dominate (an exact
    # match gets probability 1.0); equidistant neighbours still share the
    # probability and are rejected.
    # NOTE: with fewer than 3 names, prediction raises (k > n_samples); the
    # caller treats ML failures as "no match".
    knn_model = KNeighborsClassifier(n_neighbors=3, weights="distance")
    knn_model.fit(X, operation_names)
    model_path = get_model_path(language)
    # joblib.dump does not create missing directories.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    joblib.dump((knn_model, vectorizer), model_path)
    return knn_model, vectorizer

def load_ml_model(language):
    """Load the persisted KNN model and vectorizer for *language*.

    Returns:
        Whatever ``joblib.load`` reads from ``get_model_path(language)``,
        normally the ``(knn_model, vectorizer)`` pair saved by
        ``train_ml_model``, or ``(None, None)`` when no such file exists.
    """
    model_path = get_model_path(language)
    if not os.path.exists(model_path):
        return None, None
    # NOTE(review): joblib.load unpickles the file, which can execute
    # arbitrary code. Only load model files this application wrote itself.
    return joblib.load(model_path)

def predict_operation_name_ml(user_input, knn_model, vectorizer, min_probability=0.5):
    """Predicts the operation name using the trained KNN model.

    Args:
        user_input: The query string.
        knn_model: Fitted classifier providing ``predict_proba`` and ``classes_``.
        vectorizer: The vectorizer fitted alongside *knn_model*.
        min_probability: Lowest top-class probability accepted as a match.
            The default of 0.5 is the previous hard-coded cut-off.
    Returns:
        ``(name, MatchMethod.ML)``, where *name* is "" when the highest class
        probability is below *min_probability*.
    """
    input_vector = vectorizer.transform([user_input])
    probabilities = knn_model.predict_proba(input_vector)[0]
    best_index = probabilities.argmax()
    if probabilities[best_index] < min_probability:
        return "", MatchMethod.ML
    # predict_proba columns follow ``classes_`` order. Picking the argmax here
    # avoids a second neighbour search via predict(); ties resolve to the
    # first class, the same as predict().
    return knn_model.classes_[best_index], MatchMethod.ML

def get_operation_definition(user_input, language):
    """Return the stored definition of the operation that best matches *user_input*.

    The input is cleaned of language keywords and normalised with
    ``preprocess_text``. Strategies are then tried in order:
      1. substring containment,
      2. fuzzy matching,
      3. the KNN model, loaded from disk or trained on the current
         operations if no model file exists yet.

    Args:
        user_input: Free-text operation name, possibly with keywords
            (e.g. "def get user").
        language: Programming language name, matched case-insensitively.
    Returns:
        The value ``load_functions`` stored for the matched name, or "" when
        no operations are available, the query is empty, nothing matches, or
        the ML step fails.
    """
    # Normalise once so keyword stripping, DB lookup and model path agree.
    if isinstance(language, str):
        language = language.lower()

    operations = load_functions(language)
    if not operations:
        # Nothing to match against; training a model on zero names would raise.
        return ""

    query = preprocess_text(clean_function_name(user_input, language))
    if not query:
        # Input was only keywords/whitespace. An empty string is contained in
        # every name and must not be treated as a match.
        return ""

    # 1) Substring containment.
    closest_match, _ = find_closest_operation_substring(query, operations)
    if closest_match:
        return operations[closest_match]

    # 2) Fuzzy matching.
    closest_match, _ = find_closest_operation_fuzzy(query, operations)
    if closest_match:
        return operations[closest_match]

    # 3) KNN model. Loading, training and prediction can all fail (corrupt
    # model file, too few names, unwritable directory); treat that as no match.
    try:
        knn_model, vectorizer = load_ml_model(language)
        if knn_model is None:
            knn_model, vectorizer = train_ml_model(operations, language)
        closest_match, _ = predict_operation_name_ml(query, knn_model, vectorizer)
    except Exception as exc:
        print(f"ML operation matching failed: {exc}")
        return ""

    if not closest_match:
        return ""
    # A model loaded from disk may have been trained on an older set of
    # operations and can predict a name that no longer exists.
    return operations.get(closest_match, "")