Spaces:
Sleeping
Sleeping
| import warnings | |
| warnings.filterwarnings('ignore') | |
| import streamlit as st | |
| import numpy as np | |
| import joblib | |
| import torch | |
| import torch.nn as nn | |
| import ast | |
| from transformers import RobertaTokenizer, RobertaModel | |
# ── Page config ───────────────────────────────────────────────────────────────
# Set the browser-tab title/icon and use the centered layout.
st.set_page_config(page_title="AI Code Detector", page_icon="π", layout="centered")

# ── Device ────────────────────────────────────────────────────────────────────
# Torch device handle for this app: CPU.
device = torch.device("cpu")
# ── CodeBERT Architecture ─────────────────────────────────────────────────────
class CodeBERTClassifier(nn.Module):
    """Linear classification head on top of the pretrained CodeBERT encoder.

    The encoder's hidden state at the first token position is passed through
    dropout and a single linear layer that produces one logit per class.

    Attribute names (``codebert``, ``dropout``, ``classifier``) are kept stable
    because they determine the keys of the module's ``state_dict``.

    Args:
        dropout: Dropout probability applied before the linear head.
        num_labels: Number of output logits. Defaults to 2.
    """

    def __init__(self, dropout=0.1, num_labels=2):
        super().__init__()
        self.codebert = RobertaModel.from_pretrained('microsoft/codebert-base')
        self.dropout = nn.Dropout(dropout)
        # Take the head's input width from the encoder config instead of
        # hard-coding 768, so the two can never disagree.
        self.classifier = nn.Linear(self.codebert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return raw (pre-softmax) logits for a batch of tokenized inputs.

        Args:
            input_ids: Token-id tensor forwarded to the encoder.
            attention_mask: Attention-mask tensor forwarded to the encoder.

        Returns:
            The output of the linear head applied to the first-position
            hidden state of each sequence.
        """
        encoded = self.codebert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Position 0 of the last hidden state serves as the sequence summary.
        first_token_state = encoded.last_hidden_state[:, 0, :]
        return self.classifier(self.dropout(first_token_state))
# ── Load models (cached so they load only once) ───────────────────────────────
# Streamlit re-executes the whole script on every interaction. Without a
# resource cache, every click would reload all pickles and the CodeBERT
# weights. The caller shows its own spinner, so the built-in one is disabled.
@st.cache_resource(show_spinner=False)
def load_models():
    """Load every model artifact the app needs.

    Reads the joblib-pickled scaler and three classifiers from ``models/``,
    the CodeBERT tokenizer, and the fine-tuned ``CodeBERTClassifier`` weights
    from ``models/best_model.pt``. The network is put in eval mode and run
    once on a dummy batch as a smoke test whose result is only printed.

    Returns:
        tuple: ``(scaler, lr_model, svm_model, rf_model, tokenizer, cb_model)``.
    """
    scaler = joblib.load("models/scaler.pkl")
    lr_model = joblib.load("models/logistic_regression.pkl")
    svm_model = joblib.load("models/svm.pkl")
    rf_model = joblib.load("models/random_forest.pkl")
    tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')

    print("Loading CodeBERT weights...")
    cb_model = CodeBERTClassifier()
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects. That is acceptable only if the checkpoint is a trusted file
    # shipped with the app. Confirm, and prefer weights_only=True if the file
    # is a plain state_dict.
    state_dict = torch.load(
        "models/best_model.pt",
        map_location=device,
        weights_only=False  # required for cross-version compatibility
    )
    cb_model.load_state_dict(state_dict, strict=True)
    cb_model.eval()

    # Sanity check: verify the model outputs non-trivial probabilities.
    with torch.no_grad():
        dummy_ids = torch.zeros(1, 512, dtype=torch.long)
        dummy_mask = torch.ones(1, 512, dtype=torch.long)
        dummy_out = cb_model(dummy_ids, dummy_mask)
        dummy_probs = torch.softmax(dummy_out, dim=1)[0].numpy()
    print(f"CodeBERT sanity check β Human: {dummy_probs[0]:.4f}, AI: {dummy_probs[1]:.4f}")
    # A near-certain "Human" on an all-zero input is treated as a sign that
    # the fine-tuned weights did not load properly.
    if dummy_probs[0] > 0.9999:
        print("WARNING: CodeBERT may not have loaded correctly")
    else:
        print("CodeBERT loaded correctly")

    print("All models ready")
    return scaler, lr_model, svm_model, rf_model, tokenizer, cb_model
# ── Ensemble weights ──────────────────────────────────────────────────────────
# Each model's score raised to the 4th power, then normalized to sum to 1.
# Order: Logistic Regression, SVM, Random Forest, CodeBERT.
_raw = np.array([score ** 4 for score in (0.8179, 0.8708, 0.8860, 0.9983)])
WEIGHTS = _raw / np.sum(_raw)
# ── Feature extraction ────────────────────────────────────────────────────────
def get_cyclomatic_complexity(func_node):
    """Return a cyclomatic-complexity style count for an AST subtree.

    Starts at 1 and adds one for every ``if``, ``for``, ``while`` and
    ``except`` node, plus ``len(values) - 1`` for every boolean operation
    (``and`` / ``or``) found anywhere under ``func_node``.
    """
    branching = (ast.If, ast.For, ast.While, ast.ExceptHandler)
    decision_points = sum(
        len(node.values) - 1 if isinstance(node, ast.BoolOp) else 1
        for node in ast.walk(func_node)
        if isinstance(node, branching + (ast.BoolOp,))
    )
    return 1 + decision_points
def get_max_nesting_depth(code):
    """Estimate the deepest nesting level of ``code`` from its indentation.

    Blank lines and lines whose first non-space character is ``#`` are
    ignored. For every other line, the number of leading whitespace
    characters is integer-divided by 4. The largest quotient is returned
    (0 when no line qualifies).
    """
    indent_widths = (
        len(raw) - len(raw.lstrip())
        for raw in code.split('\n')
        if raw.strip() and not raw.strip().startswith('#')
    )
    return max((width // 4 for width in indent_widths), default=0)
def get_variable_stats(func_node):
    """Summarize the plain-name assignment targets under ``func_node``.

    Collects the target names of ``=``, augmented (``+=`` etc.) and annotated
    assignments whose target is a bare name. Tuple, attribute and subscript
    targets are skipped.

    Returns:
        tuple: ``(number of distinct names, mean name length rounded to 2
        decimals)``. The mean is taken over every occurrence, not over the
        distinct names. ``(0, 0)`` is returned when nothing was collected.
    """
    assigned = []
    for stmt in ast.walk(func_node):
        if isinstance(stmt, ast.Assign):
            targets = stmt.targets
        elif isinstance(stmt, (ast.AugAssign, ast.AnnAssign)):
            targets = [stmt.target]
        else:
            continue
        assigned.extend(t.id for t in targets if isinstance(t, ast.Name))

    if not assigned:
        return 0, 0
    mean_length = round(np.mean([len(name) for name in assigned]), 2)
    return len(set(assigned)), mean_length
def extract_features(code):
    """Turn the source of a single function into a 22-element feature list.

    ``code`` must parse and its first top-level statement must be a ``def``;
    otherwise ``None`` is returned. Any exception raised along the way is
    swallowed and also yields ``None``.

    Returns:
        list | None: Features in this fixed order:
        code_lines, blank_lines, avg_line_length, cyclomatic_complexity,
        num_loops, num_exceptions, max_nesting_depth, num_returns,
        has_docstring, docstring_lines, comment_lines, num_vars,
        avg_var_len, has_type_hints, num_assertions, num_raises,
        uses_list_comp, uses_lambda, uses_fstring, uses_with,
        num_calls, has_nested.
    """
    try:
        source_lines = code.split('\n')
        stripped_lines = [raw.strip() for raw in source_lines]
        blank_lines = stripped_lines.count('')
        comment_lines = sum(text.startswith('#') for text in stripped_lines)

        module = ast.parse(code)
        func = module.body[0] if module.body else None
        if not isinstance(func, ast.FunctionDef):
            return None

        # A docstring is a leading expression statement holding a str constant.
        has_docstring, docstring_lines = 0, 0
        first_stmt = func.body[0] if func.body else None
        if (isinstance(first_stmt, ast.Expr)
                and isinstance(first_stmt.value, ast.Constant)
                and isinstance(first_stmt.value.value, str)):
            has_docstring = 1
            docstring_lines = len(first_stmt.value.value.split('\n'))

        # Never report fewer than one code line.
        code_lines = max(
            len(source_lines) - blank_lines - comment_lines - docstring_lines,
            1,
        )

        # Line length is measured on the raw (unstripped) non-blank lines.
        line_lengths = [
            len(raw) for raw, text in zip(source_lines, stripped_lines) if text
        ]
        avg_line_length = round(np.mean(line_lengths), 2) if line_lengths else 0

        has_type_hints = int(
            func.returns is not None
            or any(arg.annotation is not None for arg in func.args.args)
        )

        # Walk the function once and reuse the node list for every tally.
        nodes = list(ast.walk(func))

        def count(*kinds):
            return sum(isinstance(node, kinds) for node in nodes)

        def present(kind):
            return 1 if count(kind) else 0

        has_nested = 1 if any(
            isinstance(node, ast.FunctionDef) and node is not func
            for node in nodes
        ) else 0
        num_vars, avg_var_len = get_variable_stats(func)

        return [
            code_lines, blank_lines, avg_line_length,
            get_cyclomatic_complexity(func),
            count(ast.For, ast.While),      # num_loops
            count(ast.ExceptHandler),       # num_exceptions
            get_max_nesting_depth(code),
            count(ast.Return),              # num_returns
            has_docstring, docstring_lines, comment_lines,
            num_vars, avg_var_len, has_type_hints,
            count(ast.Assert),              # num_assertions
            count(ast.Raise),               # num_raises
            present(ast.ListComp),          # uses_list_comp
            present(ast.Lambda),            # uses_lambda
            present(ast.JoinedStr),         # uses_fstring
            present(ast.With),              # uses_with
            count(ast.Call),                # num_calls
            has_nested,
        ]
    except Exception:
        return None
# ── Prediction ────────────────────────────────────────────────────────────────
def predict(code, scaler, lr_model, svm_model, rf_model, tokenizer, cb_model):
    """Classify one Python function as human-written (0) or AI-generated (1).

    Args:
        code: Source text. After stripping it must start with ``def `` and
            parse as a function definition.
        scaler: Fitted transformer applied to the features for LR and SVM.
        lr_model, svm_model, rf_model: Classifiers exposing ``predict_proba``.
        tokenizer: Tokenizer producing ``input_ids`` / ``attention_mask``.
        cb_model: CodeBERT classifier returning two logits per input.

    Returns:
        tuple: ``(results, None)`` on success, where ``results`` holds the
        ensemble verdict, each model's predicted label and AI probability,
        and the raw feature list. ``(None, message)`` is returned when the
        input is rejected.
    """
    source = code.strip()
    if not source.startswith('def '):
        return None, "Input must start with 'def'. Please paste a complete Python function."

    try:
        module = ast.parse(source)
    except SyntaxError as exc:
        return None, f"Syntax error: {exc}"
    if not (module.body and isinstance(module.body[0], ast.FunctionDef)):
        return None, "No function definition found."

    features = extract_features(source)
    if features is None:
        return None, "Could not extract features. Check your input."

    raw_features = np.array(features, dtype=float).reshape(1, -1)
    scaled_features = scaler.transform(raw_features)

    # NOTE(review): the random forest is fed the unscaled features while LR
    # and SVM get the scaled ones. Presumably this mirrors training; confirm.
    lr_prob = lr_model.predict_proba(scaled_features)[0]
    svm_prob = svm_model.predict_proba(scaled_features)[0]
    rf_prob = rf_model.predict_proba(raw_features)[0]

    tokens = tokenizer(
        source,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    with torch.no_grad():
        logits = cb_model(tokens['input_ids'], tokens['attention_mask'])
        cb_prob = torch.softmax(logits, dim=1)[0].numpy()

    # Index 1 of every probability vector is the "AI" class.
    ai_probs = np.array([lr_prob[1], svm_prob[1], rf_prob[1], cb_prob[1]])
    ensemble_prob = float(np.dot(WEIGHTS, ai_probs))

    results = {
        'ensemble_pred': int(ensemble_prob >= 0.5),
        'ensemble_prob': ensemble_prob,
        'lr_pred': int(np.argmax(lr_prob)),
        'lr_prob': lr_prob[1],
        'svm_pred': int(np.argmax(svm_prob)),
        'svm_prob': svm_prob[1],
        'rf_pred': int(np.argmax(rf_prob)),
        'rf_prob': rf_prob[1],
        'cb_pred': int(np.argmax(cb_prob)),
        'cb_prob': cb_prob[1],
        'features': features,
    }
    return results, None
# ── Streamlit UI ──────────────────────────────────────────────────────────────
# Page header and a short description of the ensemble.
# NOTE(review): the non-ASCII glyphs in these labels look mis-encoded in this
# copy of the file. They are reproduced unchanged; confirm against the
# deployed source.
st.title("π AI Code Detector")
st.markdown(
    "Paste any standalone Python function to detect whether it was written by a **human** or generated by **AI**."
)
st.info(
    "**4 models with weighted ensemble:** \n"
    "π΅ Logistic Regression (17%) | "
    "π SVM (22%) | "
    "π’ Random Forest (23%) | "
    "π΄ CodeBERT (38%)"
)

# Fetch all model artifacts, showing a spinner while they load.
with st.spinner("Loading models... (first load takes ~30 seconds)"):
    (scaler, lr_model, svm_model,
     rf_model, tokenizer, cb_model) = load_models()
st.success("All models loaded and ready.")

# Text box where the user pastes the function to analyse.
code_input = st.text_area(
    "Python Function",
    height=300,
    placeholder="Paste your Python function here...\n\ndef my_function(x, y):\n result = x + y\n return result",
)
| # Detect button | |
# Run the ensemble when the button is pressed and render the verdict,
# per-model results, ensemble weights and a selection of extracted features.
# NOTE(review): the non-ASCII glyphs in the labels below look mis-encoded in
# this copy of the file. They are reproduced unchanged; confirm against the
# deployed source.
if st.button("π Detect", type="primary"):
    if not code_input or code_input.strip() == '':
        st.warning("Please paste a Python function first.")
    else:
        with st.spinner("Analysing... (CodeBERT may take 15-20 seconds on CPU)"):
            results, error = predict(
                code_input,
                scaler, lr_model, svm_model,
                rf_model, tokenizer, cb_model
            )
        if error:
            st.error(error)
        else:
            # Verdict: show the probability of whichever class was chosen.
            if results['ensemble_pred'] == 1:
                prob_pct = results['ensemble_prob'] * 100
                st.error(f"## π€ AI GENERATED β {prob_pct:.1f}% AI probability")
            else:
                prob_pct = (1 - results['ensemble_prob']) * 100
                st.success(f"## π€ HUMAN WRITTEN β {prob_pct:.1f}% Human probability")

            # Individual models
            st.markdown("### Individual Model Predictions")
            col1, col2, col3, col4 = st.columns(4)

            def model_card(col, name, pred, prob):
                """Render one model's label and AI probability in ``col``."""
                label = "π€ AI" if pred == 1 else "π€ Human"
                col.metric(name, label, f"{prob*100:.1f}% AI")

            model_card(col1, "π΅ LR", results['lr_pred'], results['lr_prob'])
            model_card(col2, "π SVM", results['svm_pred'], results['svm_prob'])
            model_card(col3, "π’ RF", results['rf_pred'], results['rf_prob'])
            model_card(col4, "π΄ CodeBERT", results['cb_pred'], results['cb_prob'])

            # Ensemble weights. The percentages are derived from WEIGHTS so
            # the table cannot drift from the values used in the prediction.
            st.markdown("### Ensemble Weights")
            weights_data = {
                "Model": ["Logistic Regression", "SVM", "Random Forest", "CodeBERT"],
                "Weight": [f"{weight * 100:.1f}%" for weight in WEIGHTS],
                "F1 Score": ["0.818", "0.871", "0.886", "0.998"],
            }
            import pandas as pd
            st.table(pd.DataFrame(weights_data))

            # Features. Indices follow the order of the list returned by
            # extract_features().
            st.markdown("### Key Features Extracted")
            f = results['features']
            feat_col1, feat_col2 = st.columns(2)
            with feat_col1:
                st.markdown(f"- **code_lines:** {f[0]}")
                st.markdown(f"- **blank_lines:** {f[1]}")
                st.markdown(f"- **avg_line_length:** {f[2]}")
                st.markdown(f"- **cyclomatic_complexity:** {f[3]}")
                st.markdown(f"- **has_docstring:** {'Yes' if f[8] else 'No'}")
            with feat_col2:
                st.markdown(f"- **docstring_lines:** {f[9]}")
                st.markdown(f"- **num_comments:** {f[10]}")
                st.markdown(f"- **num_function_calls:** {f[20]}")
                st.markdown(f"- **num_unique_variables:** {f[11]}")
                st.markdown(f"- **avg_var_name_length:** {f[12]}")
| # Example functions | |
# Collapsible panel with two sample functions the user can paste into the
# detector. The snippets are indented so that they are valid Python and are
# accepted by the detector's syntax check.
with st.expander("Show example functions to test"):
    st.markdown("**Example 1 β Likely Human Written:**")
    st.code('''def calculate_statistics(data):
    """Calculate basic statistics for a dataset."""
    if not data:
        raise ValueError("Data cannot be empty")
    sorted_data = sorted(data)
    n = len(sorted_data)
    mean = sum(sorted_data) / n
    if n % 2 == 0:
        median = (sorted_data[n//2 - 1] + sorted_data[n//2]) / 2
    else:
        median = sorted_data[n//2]
    variance = sum((x - mean) ** 2 for x in sorted_data) / n
    return {"mean": round(mean, 4), "median": round(median, 4),
            "std": round(variance ** 0.5, 4)}''', language="python")
    st.markdown("**Example 2 β Likely AI Generated:**")
    st.code('''def add_numbers(a, b):
    result = a + b
    return result''', language="python")