from flask import Flask, request, jsonify
import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification
from contextlib import nullcontext
from functools import lru_cache


app = Flask(__name__)


model = None
tokenizer = None
device = None


def setup_device():
    # Prefer CUDA, then Apple MPS, and fall back to CPU.
    if torch.cuda.is_available():
        return torch.device('cuda')
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return torch.device('mps')
    else:
        return torch.device('cpu')


def load_tokenizer():
    try:
        # Load the fine-tuned tokenizer saved alongside the model.
        tokenizer = RobertaTokenizer.from_pretrained('./tokenizer_vulnerability')
        tokenizer.model_max_length = 512
        return tokenizer
    except Exception as e:
        # Fall back to the base CodeBERT tokenizer if the local copy is missing.
        print(f"Error loading tokenizer: {e}")
        return RobertaTokenizer.from_pretrained('microsoft/codebert-base')
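# Note on the fallback above: the base CodeBERT tokenizer only matches the
# fine-tuned model if training kept the original vocabulary unchanged (an
# assumption this script does not verify).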


def load_model():
    global device
    device = setup_device()
    print(f"Using device: {device}")

    try:
        checkpoint = torch.load("codebert_vulnerability_scorer.pth", map_location=device)

        # Rebuild the architecture from the saved config when available;
        # otherwise fall back to the base CodeBERT classification head.
        if 'config' in checkpoint:
            from transformers import RobertaConfig
            config = RobertaConfig.from_dict(checkpoint['config'])
            model = RobertaForSequenceClassification(config)
        else:
            model = RobertaForSequenceClassification.from_pretrained(
                'microsoft/codebert-base',
                num_labels=1
            )

        # Support both full-checkpoint and bare state-dict formats.
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)

        model.to(device)
        model.eval()

        # Half precision on CUDA cuts memory use and speeds up inference.
        if device.type == 'cuda':
            model.half()

        return model

    except Exception as e:
        print(f"Error loading model: {e}")
        raise
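# Note: recent PyTorch releases default torch.load to weights_only=True,
# which can reject full pickled checkpoints; if you trust the file, passing
# weights_only=False restores the old behavior (depends on your PyTorch version).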


@lru_cache(maxsize=1000)
def cached_tokenize(code, max_length):
    # lru_cache keys on the code string itself, so repeated snippets reuse
    # the same tokenization instead of re-encoding.
    return tokenizer(
        code,
        truncation=True,
        padding='max_length',
        max_length=max_length,
        return_tensors='pt'
    )
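# cached_tokenize is not wired into the request path below; it could replace
# the direct tokenizer(...) calls when identical snippets repeat, e.g.
# inputs = cached_tokenize(code, dynamic_length).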


# Load the tokenizer and model once at import time so the first request does
# not pay the startup cost.
try:
    print("Loading tokenizer...")
    tokenizer = load_tokenizer()
    print("Tokenizer loaded successfully!")

    print("Loading model...")
    model = load_model()
    print("Model loaded successfully!")

except Exception as e:
    print(f"Error during initialization: {str(e)}")
    tokenizer = None
    model = None


@app.route("/", methods=['GET'])
def home():
    return jsonify({
        "message": "CodeBERT Vulnerability Evaluator API",
        "status": "Model loaded" if model is not None else "Model not loaded",
        "device": str(device) if device else "unknown",
        "endpoints": {
            "/predict": "POST with JSON body containing 'codes' array"
        }
    })


@app.route("/predict", methods=['POST'])
def predict_batch():
    try:
        if model is None or tokenizer is None:
            return jsonify({"error": "Model not loaded properly"}), 500

        data = request.get_json()
        if not data or 'codes' not in data:
            return jsonify({"error": "Missing 'codes' field in JSON body"}), 400

        codes = data['codes']
        if not isinstance(codes, list) or len(codes) == 0:
            return jsonify({"error": "'codes' must be a non-empty array"}), 400

        # Score at most 16 snippets per forward pass to bound memory use.
        batch_size = min(len(codes), 16)
        results = []

        for i in range(0, len(codes), batch_size):
            batch = codes[i:i + batch_size]
            scores = predict_vulnerability_batch(batch)

            for score in scores:
                results.append({"score": score})

        return jsonify({"results": results})

    except Exception as e:
        return jsonify({"error": f"Batch prediction error: {str(e)}"}), 500
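# Example request (the snippet payload is illustrative):
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"codes": ["strcpy(buf, user_input);"]}'
# The response has the shape {"results": [{"score": <float in [0, 1]>}]}.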


def predict_vulnerability(code):
    # Size the sequence to the input (roughly two tokens per whitespace-split
    # word), clamped to [128, 512], so short snippets are not padded to 512.
    dynamic_length = min(max(len(code.split()) * 2, 128), 512)

    inputs = tokenizer(
        code,
        truncation=True,
        padding='max_length',
        max_length=dynamic_length,
        return_tensors='pt'
    )

    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Autocast on CUDA; nullcontext is a no-op on CPU/MPS.
    with torch.no_grad():
        with torch.cuda.amp.autocast() if device.type == 'cuda' else nullcontext():
            outputs = model(**inputs)

    if hasattr(outputs, 'logits'):
        score = torch.sigmoid(outputs.logits).cpu().item()
    else:
        score = torch.sigmoid(outputs[0]).cpu().item()

    return round(score, 4)
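# Single-snippet variant, kept for direct use; the /predict route above calls
# the batch version below.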


def predict_vulnerability_batch(codes):
    # Pad to the longest snippet in the batch, clamped to [128, 512].
    max_len = max(len(code.split()) * 2 for code in codes)
    dynamic_length = min(max(max_len, 128), 512)

    inputs = tokenizer(
        codes,
        truncation=True,
        padding='max_length',
        max_length=dynamic_length,
        return_tensors='pt'
    )

    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        with torch.cuda.amp.autocast() if device.type == 'cuda' else nullcontext():
            outputs = model(**inputs)

    # Sigmoid maps each single regression logit to a score in (0, 1);
    # cast to float32 first since the model may be running in half precision.
    if hasattr(outputs, 'logits'):
        scores = torch.sigmoid(outputs.logits).float().cpu().numpy()
    else:
        scores = torch.sigmoid(outputs[0]).float().cpu().numpy()

    return [round(float(score), 4) for score in scores.flatten()]


@app.route("/health", methods=['GET'])
def health_check():
    return jsonify({
        "status": "healthy",
        "model_loaded": model is not None,
        "tokenizer_loaded": tokenizer is not None,
        "device": str(device) if device else "unknown"
    })


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860, debug=False, threaded=True)
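# Flask's built-in server is fine for development; for production a WSGI
# server such as gunicorn is the usual choice (deployment suggestion, not
# part of the original setup; assumes this file is named app.py):
#   gunicorn -w 1 -b 0.0.0.0:7860 app:app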