himansha2001 commited on
Commit
8953138
·
1 Parent(s): 6224c58

Modified the api file architecture and model loading

Browse files
Files changed (6) hide show
  1. app.py +112 -0
  2. explainer.py +73 -0
  3. features.py +151 -0
  4. main.py +0 -258
  5. model.py +33 -0
  6. requirements.txt +3 -2
app.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from fastapi import FastAPI, HTTPException
5
+ from pydantic import BaseModel
6
+ from transformers import AutoTokenizer
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from huggingface_hub import hf_hub_download
9
+ from safetensors.torch import load_file
10
+
11
+ from model import ComplexityFusionModel
12
+ from features import clean_code, get_python_features, get_java_features
13
+ from explainer import generate_shap_explanation
14
+
15
+ # API SETUP
16
+ app = FastAPI(title="Code Complexity XAI API", version="1.0.0")
17
+
18
+ app.add_middleware(
19
+ CORSMiddleware,
20
+ allow_origins=["*"],
21
+ allow_methods=["*"],
22
+ allow_headers=["*"],
23
+ )
24
+
25
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+ label_map = {0: 'CONSTANT', 1: 'LINEAR', 2: 'LOGN', 3: 'NLOGN', 4: 'QUADRATIC', 5: 'CUBIC', 6: 'NP'}
27
+ REPO_ID = "himansha2001/algox"
28
+
29
+ print("Booting up backend services...")
30
+ tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
31
+ model = ComplexityFusionModel(model_name="microsoft/unixcoder-base", num_labels=7, num_static_features=5)
32
+
33
+ safetensors_path = hf_hub_download(repo_id=REPO_ID, filename="model.safetensors")
34
+ state_dict = load_file(safetensors_path)
35
+ model.load_state_dict(state_dict, strict=False)
36
+ model.to(device)
37
+ model.eval()
38
+ print("API is ready for inference!")
39
+
40
+ class CodeRequest(BaseModel):
41
+ code: str
42
+ language: str
43
+
44
+ @app.get("/")
45
+ async def health_check():
46
+ """
47
+ Root endpoint to verify the API is online and the model is loaded.
48
+ """
49
+ return {
50
+ "status": "online",
51
+ "message": "Code Complexity XAI API is running successfully.",
52
+ "model_loaded": True,
53
+ "version": "1.0.0"
54
+ }
55
+
56
+ @app.post("/predict")
57
+ async def predict_complexity(request: CodeRequest):
58
+ """
59
+ Endpoint to predict the complexity of the provided code and generate an explanation.
60
+ """
61
+
62
+ lang = request.language.lower()
63
+
64
+ # Prepare Data
65
+ cleaned_code = clean_code(request.code, lang)
66
+ if lang == 'python':
67
+ feats = get_python_features(request.code)
68
+ elif lang == 'java':
69
+ feats = get_java_features(request.code)
70
+ else:
71
+ raise HTTPException(status_code=400, detail="Language must be 'java' or 'python'")
72
+
73
+ request_static_features = torch.tensor([feats], dtype=torch.float32).to(device)
74
+
75
+ # Tokenize & Predict
76
+ inputs = tokenizer(cleaned_code, return_tensors="pt", truncation=True, max_length=512).to(device)
77
+ with torch.no_grad():
78
+ logits = model(
79
+ input_ids=inputs['input_ids'],
80
+ attention_mask=inputs['attention_mask'],
81
+ static_features=request_static_features
82
+ )
83
+ probs = F.softmax(logits, dim=1)
84
+
85
+ pred_idx = probs.argmax().item()
86
+ confidence = probs.max().item()
87
+ prediction = label_map[pred_idx]
88
+
89
+ # Generate SHAP Explanation
90
+ shap_explanation = generate_shap_explanation(
91
+ cleaned_code=cleaned_code,
92
+ model=model,
93
+ tokenizer=tokenizer,
94
+ static_features_tensor=request_static_features,
95
+ device=device,
96
+ pred_idx=pred_idx,
97
+ label_map=label_map
98
+ )
99
+
100
+ # Return Response
101
+ return {
102
+ "complexity": prediction,
103
+ "confidence": float(confidence),
104
+ "static_features": {
105
+ "max_depth": feats[0],
106
+ "branch_count": feats[1],
107
+ "has_recursion": bool(feats[2]),
108
+ "has_log_math": bool(feats[3]),
109
+ "has_sort": bool(feats[4])
110
+ },
111
+ "shap_explanation": shap_explanation
112
+ }
explainer.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # explainer.py
2
+ import shap
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ def generate_shap_explanation(
7
+ cleaned_code: str,
8
+ model: torch.nn.Module,
9
+ tokenizer,
10
+ static_features_tensor: torch.Tensor,
11
+ device: torch.device,
12
+ pred_idx: int,
13
+ label_map: dict
14
+ ):
15
+ """
16
+ Generates SHAP token importance scores for the predicted complexity class.
17
+ """
18
+
19
+ # SHAP Prediction Wrapper
20
+ def text_prediction_wrapper(texts):
21
+ texts_list = [str(t) for t in texts]
22
+ encodings = tokenizer(
23
+ texts_list,
24
+ return_tensors="pt",
25
+ padding=True,
26
+ truncation=True,
27
+ max_length=512
28
+ ).to(device)
29
+
30
+ # Expand static features to match SHAP permutation batch size
31
+ batch_size = encodings['input_ids'].shape[0]
32
+ expanded_static = static_features_tensor.repeat(batch_size, 1)
33
+
34
+ with torch.no_grad():
35
+ batch_logits = model(
36
+ input_ids=encodings['input_ids'],
37
+ attention_mask=encodings['attention_mask'],
38
+ static_features=expanded_static
39
+ )
40
+ return F.softmax(batch_logits, dim=1).cpu().numpy()
41
+
42
+ # Configure SHAP Explainer
43
+ masker = shap.maskers.Text(tokenizer, mask_token=tokenizer.mask_token)
44
+ explainer = shap.Explainer(
45
+ text_prediction_wrapper,
46
+ masker,
47
+ output_names=list(label_map.values())
48
+ )
49
+
50
+ # Calculate Values (max_evals=100 for API speed)
51
+ shap_values = explainer([cleaned_code], max_evals=100)
52
+
53
+ # Extract the specific tokens and their impact scores for the predicted class
54
+ tokens = shap_values.data[0]
55
+ scores = shap_values.values[0, :, pred_idx]
56
+
57
+ # Map Character Offsets for Frontend Highlighting
58
+ encoding = tokenizer(cleaned_code, return_offsets_mapping=True, truncation=True, max_length=512)
59
+ offsets = encoding["offset_mapping"]
60
+
61
+ token_data = []
62
+ for i, (t, s) in enumerate(zip(tokens, scores)):
63
+ start_char = offsets[i][0] if i < len(offsets) else 0
64
+ end_char = offsets[i][1] if i < len(offsets) else 0
65
+
66
+ token_data.append({
67
+ "token": t,
68
+ "score": float(s),
69
+ "start_char": start_char,
70
+ "end_char": end_char
71
+ })
72
+
73
+ return token_data
features.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import re
3
+ import javalang
4
+
5
+ # Java Cleaner
6
+ def clean_java_code(code):
7
+ code = re.sub(r'//.*', '', code)
8
+ code = re.sub(r'/\*[\s\S]*?\*/', '', code)
9
+ code = re.sub(r'^\s*import\s+.*;', '', code, flags=re.MULTILINE)
10
+ code = re.sub(r'^\s*package\s+.*;', '', code, flags=re.MULTILINE)
11
+ code = re.sub(r'\n\s*\n', '\n', code)
12
+ return code.strip()
13
+
14
+ # Python Cleaner
15
+ def clean_python_code(code):
16
+ code = re.sub(r'#.*', '', code)
17
+ code = re.sub(r'""".*?"""', '', code, flags=re.DOTALL)
18
+ code = re.sub(r"'''.*?'''", '', code, flags=re.DOTALL)
19
+ code = re.sub(r'^\s*import\s+.*', '', code, flags=re.MULTILINE)
20
+ code = re.sub(r'^\s*from\s+.*import.*', '', code, flags=re.MULTILINE)
21
+ code = re.sub(r'\n\s*\n', '\n', code)
22
+ return code.strip()
23
+
24
+
25
+ def clean_code(code, lang):
26
+ if lang == 'python':
27
+ return clean_python_code(code)
28
+ else:
29
+ return clean_java_code(code)
30
+
31
+
32
+ def get_python_features(code):
33
+ try:
34
+ tree = ast.parse(code)
35
+ except:
36
+ return [0, 0, 0, 0, 0]
37
+
38
+ max_depth = 0
39
+ branch_count = 0
40
+ has_recursion = 0
41
+ has_log_math = 0
42
+ has_sort = 0
43
+
44
+ current_functions = []
45
+
46
+ class DepthVisitor(ast.NodeVisitor):
47
+ def __init__(self):
48
+ self.max_depth = 0
49
+ self.current_depth = 0
50
+
51
+ def visit_For(self, node):
52
+ self.current_depth += 1
53
+ self.max_depth = max(self.max_depth, self.current_depth)
54
+ self.generic_visit(node)
55
+ self.current_depth -= 1
56
+
57
+ def visit_While(self, node):
58
+ self.current_depth += 1
59
+ self.max_depth = max(self.max_depth, self.current_depth)
60
+ self.generic_visit(node)
61
+ self.current_depth -= 1
62
+
63
+ def visit_ListComp(self, node):
64
+ self.current_depth += len(node.generators)
65
+ self.max_depth = max(self.max_depth, self.current_depth)
66
+ self.generic_visit(node)
67
+ self.current_depth -= len(node.generators)
68
+
69
+ depth_visitor = DepthVisitor()
70
+ depth_visitor.visit(tree)
71
+ max_depth = depth_visitor.max_depth
72
+
73
+ for node in ast.walk(tree):
74
+ # Branch Counting
75
+ if isinstance(node, (ast.If, ast.While, ast.For, ast.AsyncFor, ast.ListComp)):
76
+ branch_count += 1
77
+
78
+ # Recursion & Sort Detection
79
+ if isinstance(node, ast.FunctionDef):
80
+ current_functions.append(node.name)
81
+
82
+ if isinstance(node, ast.Call):
83
+ # Recursion
84
+ if isinstance(node.func, ast.Name) and node.func.id in current_functions:
85
+ has_recursion = 1
86
+ # Sort Detection: sorted(arr)
87
+ if isinstance(node.func, ast.Name) and node.func.id == 'sorted':
88
+ has_sort = 1
89
+ # Sort Detection: arr.sort()
90
+ if isinstance(node.func, ast.Attribute) and node.func.attr == 'sort':
91
+ has_sort = 1
92
+
93
+ # Logarithmic Math Detection
94
+ if isinstance(node, ast.BinOp):
95
+ if isinstance(node.op, (ast.Div, ast.FloorDiv, ast.RShift, ast.Mult, ast.LShift)):
96
+ has_log_math = 1
97
+ if isinstance(node, ast.AugAssign):
98
+ if isinstance(node.op, (ast.Div, ast.FloorDiv, ast.RShift, ast.Mult, ast.LShift)):
99
+ has_log_math = 1
100
+
101
+ # Return 5 features
102
+ return [max_depth, branch_count, has_recursion, has_log_math, has_sort]
103
+
104
+ def get_java_features(code):
105
+ try:
106
+ if "class " not in code:
107
+ tokens = javalang.tokenizer.tokenize("class Dummy { " + code + " }")
108
+ else:
109
+ tokens = javalang.tokenizer.tokenize(code)
110
+ parser = javalang.parser.Parser(tokens)
111
+ tree = parser.parse_member_declaration()
112
+ except:
113
+ return [0, 0, 0, 0, 0]
114
+
115
+ real_max_depth = 0
116
+ branch_count = 0
117
+ has_recursion = 0
118
+ has_log_math = 0
119
+ has_sort = 0
120
+
121
+ # Max Depth
122
+ for path, node in tree.filter(javalang.tree.ForStatement):
123
+ current = sum(1 for p in path if isinstance(p, (javalang.tree.ForStatement, javalang.tree.WhileStatement, javalang.tree.DoStatement)))
124
+ real_max_depth = max(real_max_depth, current + 1)
125
+
126
+ for path, node in tree.filter(javalang.tree.WhileStatement):
127
+ current = sum(1 for p in path if isinstance(p, (javalang.tree.ForStatement, javalang.tree.WhileStatement, javalang.tree.DoStatement)))
128
+ real_max_depth = max(real_max_depth, current + 1)
129
+
130
+ # Branch Count
131
+ for path, node in tree.filter(javalang.tree.IfStatement):
132
+ branch_count += 1
133
+
134
+ # Recursion & Sorting
135
+ methods = [node.name for path, node in tree.filter(javalang.tree.MethodDeclaration)]
136
+ for path, node in tree.filter(javalang.tree.MethodInvocation):
137
+ if node.member in methods:
138
+ has_recursion = 1
139
+ if node.member == 'sort':
140
+ has_sort = 1
141
+
142
+ # AST-Based Log Math
143
+ for path, node in tree.filter(javalang.tree.BinaryOperation):
144
+ if node.operator in ['/', '*', '>>', '<<', '>>>']:
145
+ has_log_math = 1
146
+
147
+ for path, node in tree.filter(javalang.tree.Assignment):
148
+ if node.type in ['/=', '*=', '>>=', '<<=', '>>>=']:
149
+ has_log_math = 1
150
+
151
+ return [real_max_depth, branch_count, has_recursion, has_log_math, has_sort]
main.py DELETED
@@ -1,258 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import ast
4
- import re
5
- import javalang
6
- import shap
7
- import numpy as np
8
- from fastapi import FastAPI, HTTPException
9
- from pydantic import BaseModel
10
- from transformers import AutoTokenizer, AutoModel
11
- from fastapi.middleware.cors import CORSMiddleware
12
- from huggingface_hub import hf_hub_download
13
-
14
- class ComplexityFusionModel(nn.Module):
15
- def __init__(self, model_name, num_labels, num_static_features):
16
- super().__init__()
17
- self.encoder = AutoModel.from_pretrained(model_name)
18
- hidden_size = self.encoder.config.hidden_size
19
-
20
- self.static_mlp = nn.Sequential(
21
- nn.Linear(num_static_features, 32),
22
- nn.ReLU(),
23
- nn.Linear(32, 32),
24
- nn.ReLU(),
25
- nn.Dropout(0.3)
26
- )
27
-
28
- self.classifier = nn.Sequential(
29
- nn.Linear(hidden_size + 32, 128),
30
- nn.ReLU(),
31
- nn.Dropout(0.3),
32
- nn.Linear(128, num_labels)
33
- )
34
-
35
- def forward(self, input_ids=None, attention_mask=None, static_features=None):
36
- outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
37
- cls_embedding = outputs.last_hidden_state[:, 0, :]
38
- static_vec = self.static_mlp(static_features)
39
-
40
- # Scaling matches your training setup
41
- fused = torch.cat((cls_embedding * 0.5, static_vec * 2.0), dim=1)
42
- logits = self.classifier(fused)
43
- return logits
44
-
45
- def get_python_features(code):
46
- try:
47
- tree = ast.parse(code)
48
- except:
49
- return [0, 0, 0, 0, 0]
50
-
51
- max_depth = 0
52
- branch_count = 0
53
- has_recursion = 0
54
- has_log_math = 0
55
- has_sort = 0
56
-
57
- function_names = []
58
-
59
- class DepthVisitor(ast.NodeVisitor):
60
- def __init__(self):
61
- self.current = 0
62
- self.max_depth = 0
63
-
64
- def visit_For(self, node):
65
- self.current += 1
66
- self.max_depth = max(self.max_depth, self.current)
67
- self.generic_visit(node)
68
- self.current -= 1
69
-
70
- def visit_While(self, node):
71
- self.current += 1
72
- self.max_depth = max(self.max_depth, self.current)
73
- self.generic_visit(node)
74
- self.current -= 1
75
-
76
- def visit_ListComp(self, node):
77
- self.current += len(node.generators)
78
- self.max_depth = max(self.max_depth, self.current)
79
- self.generic_visit(node)
80
- self.current -= len(node.generators)
81
-
82
- dv = DepthVisitor()
83
- dv.visit(tree)
84
- max_depth = dv.max_depth
85
-
86
- for node in ast.walk(tree):
87
- if isinstance(node, (ast.If, ast.For, ast.While, ast.AsyncFor)):
88
- branch_count += 1
89
-
90
- if isinstance(node, ast.FunctionDef):
91
- function_names.append(node.name)
92
-
93
- if isinstance(node, ast.Call):
94
- # recursion detection
95
- if isinstance(node.func, ast.Name) and node.func.id in function_names:
96
- has_recursion = 1
97
-
98
- if isinstance(node.func, ast.Attribute):
99
- if node.func.attr in function_names:
100
- has_recursion = 1
101
-
102
- # sort detection
103
- if isinstance(node.func, ast.Name) and node.func.id == "sorted":
104
- has_sort = 1
105
-
106
- if isinstance(node.func, ast.Attribute) and node.func.attr == "sort":
107
- has_sort = 1
108
-
109
- if isinstance(node, ast.BinOp):
110
- if isinstance(node.op, (ast.Div, ast.FloorDiv, ast.RShift, ast.LShift)):
111
- has_log_math = 1
112
-
113
- return [max_depth, branch_count, has_recursion, has_log_math, has_sort]
114
-
115
- def get_java_features(code):
116
- try:
117
- if "class " not in code:
118
- code = "class Dummy { " + code + " }"
119
-
120
- tokens = javalang.tokenizer.tokenize(code)
121
- parser = javalang.parser.Parser(tokens)
122
- tree = parser.parse_member_declaration()
123
- except:
124
- return [0, 0, 0, 0, 0]
125
-
126
- max_depth = 0
127
- branch_count = 0
128
- has_recursion = 0
129
- has_log_math = 0
130
- has_sort = 0
131
-
132
- methods = [node.name for _, node in tree.filter(javalang.tree.MethodDeclaration)]
133
-
134
- for path, node in tree.filter(javalang.tree.ForStatement):
135
- depth = sum(
136
- isinstance(p, (javalang.tree.ForStatement,
137
- javalang.tree.WhileStatement,
138
- javalang.tree.DoStatement))
139
- for p in path
140
- )
141
- max_depth = max(max_depth, depth + 1)
142
-
143
- for _, node in tree.filter(javalang.tree.IfStatement):
144
- branch_count += 1
145
-
146
- for _, node in tree.filter(javalang.tree.MethodInvocation):
147
- if node.member in methods:
148
- has_recursion = 1
149
-
150
- if node.member in ["sort", "parallelSort"]:
151
- has_sort = 1
152
-
153
- for _, node in tree.filter(javalang.tree.BinaryOperation):
154
- if node.operator in ['/', '>>', '<<', '>>>']:
155
- has_log_math = 1
156
-
157
- return [max_depth, branch_count, has_recursion, has_log_math, has_sort]
158
-
159
- def clean_code(code, lang):
160
- code = re.sub(r'\n\s*\n', '\n', code)
161
- if lang == 'java':
162
- code = re.sub(r'//.*', '', code)
163
- code = re.sub(r'/\*[\s\S]*?\*/', '', code)
164
- return code.strip()
165
-
166
- # API
167
- app = FastAPI()
168
-
169
- app.add_middleware(
170
- CORSMiddleware,
171
- allow_origins=["*"],
172
- allow_methods=["*"],
173
- allow_headers=["*"],
174
- )
175
-
176
- device = torch.device("cpu")
177
- tokenizer = AutoTokenizer.from_pretrained("microsoft/unixcoder-base")
178
- label_map = {0: 'CONSTANT', 1: 'LINEAR', 2: 'LOGN', 3: 'NLOGN', 4: 'QUADRATIC', 5: 'CUBIC', 6: 'NP'}
179
-
180
- print("Downloading model weights...")
181
- model_path = hf_hub_download(repo_id="himansha2001/algox", filename="model.pth")
182
-
183
- print("Loading model...")
184
-
185
- model = ComplexityFusionModel("microsoft/unixcoder-base", 7, 5)
186
- state_dict = torch.load(model_path, map_location=device)
187
- model.load_state_dict(state_dict, strict=False)
188
- model.to(device)
189
- model.eval()
190
- print("Model loaded successfully!")
191
-
192
- class CodeRequest(BaseModel):
193
- code: str
194
- language: str = "java"
195
-
196
- @app.post("/predict")
197
- async def predict_complexity(request: CodeRequest):
198
- lang = request.language.lower()
199
- cleaned_code = clean_code(request.code, lang)
200
-
201
- if lang == 'python':
202
- feats = get_python_features(request.code)
203
- else:
204
- feats = get_java_features(request.code)
205
-
206
- request_static_features = torch.tensor([feats], dtype=torch.float).to(device)
207
-
208
- # Predict
209
- inputs = tokenizer(cleaned_code, return_tensors="pt", truncation=True, max_length=512).to(device)
210
- with torch.no_grad():
211
- logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], static_features=request_static_features)
212
- probs = torch.nn.functional.softmax(logits, dim=1)
213
-
214
- pred_idx = probs.argmax().item()
215
- confidence = probs.max().item()
216
- prediction = label_map[pred_idx]
217
-
218
- # SHAP Wrapper
219
- def text_prediction_wrapper(texts):
220
- encodings = tokenizer(list(texts), return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
221
- batch_size = encodings['input_ids'].shape[0]
222
- expanded_static = request_static_features.repeat(batch_size, 1)
223
- with torch.no_grad():
224
- batch_logits = model(input_ids=encodings['input_ids'], attention_mask=encodings['attention_mask'], static_features=expanded_static)
225
- return torch.nn.functional.softmax(batch_logits, dim=1).cpu().numpy()
226
-
227
- # Explain (SHAP)
228
- masker = shap.maskers.Text(tokenizer, mask_token="<mask>")
229
- explainer = shap.Explainer(text_prediction_wrapper, masker, output_names=list(label_map.values()))
230
-
231
- shap_values = explainer([cleaned_code], max_evals=100)
232
-
233
- tokens = shap_values.data[0]
234
- scores = shap_values.values[0, :, pred_idx]
235
-
236
- encoding = tokenizer(cleaned_code, return_offsets_mapping=True, truncation=True, max_length=512)
237
- offsets = encoding["offset_mapping"]
238
-
239
- token_data = []
240
- for i, (t, s) in enumerate(zip(tokens, scores)):
241
- start_char = offsets[i][0] if i < len(offsets) else 0
242
- end_char = offsets[i][1] if i < len(offsets) else 0
243
-
244
- token_data.append({
245
- "token": t,
246
- "score": float(s),
247
- "start_char": start_char,
248
- "end_char": end_char
249
- })
250
-
251
- return {
252
- "complexity": prediction,
253
- "confidence": float(confidence),
254
- "static_features": {
255
- "depth": feats[0], "branches": feats[1], "recursion": feats[2], "log_hint": feats[3], "has_sort": feats[4]
256
- },
257
- "shap_explanation": token_data
258
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import AutoConfig, AutoModel
4
+ from transformers.modeling_outputs import SequenceClassifierOutput
5
+
6
+ class ComplexityFusionModel(nn.Module):
7
+ def __init__(self, model_name, num_labels, num_static_features, static_hidden_dim=16):
8
+ super(ComplexityFusionModel, self).__init__()
9
+
10
+ # Load config and base model
11
+ self.config = AutoConfig.from_pretrained(model_name)
12
+ self.codebert = AutoModel.from_pretrained(model_name)
13
+
14
+ self.static_mlp = nn.Sequential(
15
+ nn.Linear(num_static_features, static_hidden_dim),
16
+ nn.ReLU(),
17
+ nn.Dropout(0.1)
18
+ )
19
+
20
+ fusion_dim = self.config.hidden_size + static_hidden_dim
21
+ self.classifier = nn.Linear(fusion_dim, num_labels)
22
+
23
+ def forward(self, input_ids=None, attention_mask=None, static_features=None):
24
+ outputs = self.codebert(input_ids=input_ids, attention_mask=attention_mask)
25
+ bert_output = outputs.last_hidden_state[:, 0, :]
26
+
27
+ static_output = self.static_mlp(static_features)
28
+
29
+ combined_features = torch.cat((bert_output, static_output), dim=1)
30
+
31
+ logits = self.classifier(combined_features)
32
+
33
+ return logits
requirements.txt CHANGED
@@ -2,8 +2,9 @@ fastapi
2
  uvicorn
3
  torch
4
  transformers
5
- numpy
 
6
  shap
7
  javalang
8
  huggingface_hub
9
- pydantic
 
2
  uvicorn
3
  torch
4
  transformers
5
+ safetensors
6
+ numpy<2.0.0
7
  shap
8
  javalang
9
  huggingface_hub
10
+ pydantic