Update main.py
Browse files
main.py
CHANGED
|
@@ -90,42 +90,15 @@ except Exception as e:
|
|
| 90 |
@app.route("/", methods=['GET'])
|
| 91 |
def home():
|
| 92 |
return jsonify({
|
| 93 |
-
"message": "CodeBERT Vulnerability
|
| 94 |
"status": "Model loaded" if model is not None else "Model not loaded",
|
| 95 |
"device": str(device) if device else "unknown",
|
| 96 |
"endpoints": {
|
| 97 |
-
"/predict": "POST with JSON body containing '
|
| 98 |
-
"/predict_batch": "POST with JSON body containing 'codes' array",
|
| 99 |
-
"/predict_get": "GET with 'code' URL parameter"
|
| 100 |
}
|
| 101 |
})
|
| 102 |
|
| 103 |
@app.route("/predict", methods=['POST'])
|
| 104 |
-
def predict_post():
|
| 105 |
-
try:
|
| 106 |
-
if model is None or tokenizer is None:
|
| 107 |
-
return jsonify({"error": "Model not loaded properly"}), 500
|
| 108 |
-
|
| 109 |
-
data = request.get_json()
|
| 110 |
-
if not data or 'code' not in data:
|
| 111 |
-
return jsonify({"error": "Missing 'code' field in JSON body"}), 400
|
| 112 |
-
|
| 113 |
-
code = data['code']
|
| 114 |
-
if not code or not isinstance(code, str):
|
| 115 |
-
return jsonify({"error": "'code' field must be a non-empty string"}), 400
|
| 116 |
-
|
| 117 |
-
score = predict_vulnerability(code)
|
| 118 |
-
|
| 119 |
-
return jsonify({
|
| 120 |
-
"score": score,
|
| 121 |
-
"vulnerability_level": get_vulnerability_level(score),
|
| 122 |
-
"code_preview": code[:200] + "..." if len(code) > 200 else code
|
| 123 |
-
})
|
| 124 |
-
|
| 125 |
-
except Exception as e:
|
| 126 |
-
return jsonify({"error": f"Prediction error: {str(e)}"}), 500
|
| 127 |
-
|
| 128 |
-
@app.route("/predict_batch", methods=['POST'])
|
| 129 |
def predict_batch():
|
| 130 |
try:
|
| 131 |
if model is None or tokenizer is None:
|
|
@@ -148,10 +121,7 @@ def predict_batch():
|
|
| 148 |
|
| 149 |
for j, score in enumerate(scores):
|
| 150 |
results.append({
|
| 151 |
-
"
|
| 152 |
-
"score": score,
|
| 153 |
-
"vulnerability_level": get_vulnerability_level(score),
|
| 154 |
-
"code_preview": batch[j][:100] + "..." if len(batch[j]) > 100 else batch[j]
|
| 155 |
})
|
| 156 |
|
| 157 |
return jsonify({"results": results})
|
|
@@ -159,26 +129,7 @@ def predict_batch():
|
|
| 159 |
except Exception as e:
|
| 160 |
return jsonify({"error": f"Batch prediction error: {str(e)}"}), 500
|
| 161 |
|
| 162 |
-
|
| 163 |
-
def predict_get():
|
| 164 |
-
try:
|
| 165 |
-
if model is None or tokenizer is None:
|
| 166 |
-
return jsonify({"error": "Model not loaded properly"}), 500
|
| 167 |
-
|
| 168 |
-
code = request.args.get("code")
|
| 169 |
-
if not code:
|
| 170 |
-
return jsonify({"error": "Missing 'code' URL parameter"}), 400
|
| 171 |
-
|
| 172 |
-
score = predict_vulnerability(code)
|
| 173 |
-
|
| 174 |
-
return jsonify({
|
| 175 |
-
"score": score,
|
| 176 |
-
"vulnerability_level": get_vulnerability_level(score),
|
| 177 |
-
"code_preview": code[:200] + "..." if len(code) > 200 else code
|
| 178 |
-
})
|
| 179 |
-
|
| 180 |
-
except Exception as e:
|
| 181 |
-
return jsonify({"error": f"Prediction error: {str(e)}"}), 500
|
| 182 |
|
| 183 |
def predict_vulnerability(code):
|
| 184 |
dynamic_length = min(max(len(code.split()) * 2, 128), 512)
|
|
@@ -229,13 +180,6 @@ def predict_vulnerability_batch(codes):
|
|
| 229 |
|
| 230 |
return [round(float(score), 4) for score in scores.flatten()]
|
| 231 |
|
| 232 |
-
def get_vulnerability_level(score):
|
| 233 |
-
if score < 0.3:
|
| 234 |
-
return "Low"
|
| 235 |
-
elif score < 0.7:
|
| 236 |
-
return "Medium"
|
| 237 |
-
else:
|
| 238 |
-
return "High"
|
| 239 |
|
| 240 |
@app.route("/health", methods=['GET'])
|
| 241 |
def health_check():
|
|
|
|
| 90 |
@app.route("/", methods=['GET'])
|
| 91 |
def home():
|
| 92 |
return jsonify({
|
| 93 |
+
"message": "CodeBERT Vulnerability Evalutor API",
|
| 94 |
"status": "Model loaded" if model is not None else "Model not loaded",
|
| 95 |
"device": str(device) if device else "unknown",
|
| 96 |
"endpoints": {
|
| 97 |
+
"/predict": "POST with JSON body containing 'codes' array"
|
|
|
|
|
|
|
| 98 |
}
|
| 99 |
})
|
| 100 |
|
| 101 |
@app.route("/predict", methods=['POST'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
def predict_batch():
|
| 103 |
try:
|
| 104 |
if model is None or tokenizer is None:
|
|
|
|
| 121 |
|
| 122 |
for j, score in enumerate(scores):
|
| 123 |
results.append({
|
| 124 |
+
"score": score
|
|
|
|
|
|
|
|
|
|
| 125 |
})
|
| 126 |
|
| 127 |
return jsonify({"results": results})
|
|
|
|
| 129 |
except Exception as e:
|
| 130 |
return jsonify({"error": f"Batch prediction error: {str(e)}"}), 500
|
| 131 |
|
| 132 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
def predict_vulnerability(code):
|
| 135 |
dynamic_length = min(max(len(code.split()) * 2, 128), 512)
|
|
|
|
| 180 |
|
| 181 |
return [round(float(score), 4) for score in scores.flatten()]
|
| 182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
|
| 184 |
@app.route("/health", methods=['GET'])
|
| 185 |
def health_check():
|