"""Gradio web app that predicts the Big-O time complexity of a code snippet.

A ``microsoft/codebert-base`` sequence classifier, with weights loaded from
``best_model.pt``, assigns the snippet to one of the classes stored in
``label_encoder.pkl``; the predicted class is shown with a short description.
"""

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import joblib
import os

# Pretrained encoder name; the tokenizer and the model must use the same one.
BASE_MODEL = "microsoft/codebert-base"
# Longest token sequence fed to the model; longer inputs are truncated, so
# only the beginning of a very long snippet is classified.
MAX_TOKENS = 512

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Load label encoder (maps class index <-> complexity label string).
# NOTE: joblib.load unpickles the file, so it must be a trusted artifact.
le = joblib.load("label_encoder.pkl")

# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Size the classification head from the label encoder rather than a hard-coded
# count, so the head always matches the labels inverse_transform can return.
model = AutoModelForSequenceClassification.from_pretrained(
    BASE_MODEL,
    num_labels=len(le.classes_),
)
# weights_only=True limits unpickling to tensors and plain containers, which is
# all a saved state_dict needs.
model.load_state_dict(
    torch.load("best_model.pt", map_location=device, weights_only=True)
)
model.to(device)
model.eval()

# Complexity descriptions: label -> (Big-O notation, display title, explanation).
DESCRIPTIONS = {
    "constant": ("O(1)", "⚡ Constant Time", "Executes in the same time regardless of input size. Very fast!"),
    "linear": ("O(n)", "📈 Linear Time", "Execution time grows linearly with input size."),
    "logn": ("O(log n)", "🔍 Logarithmic Time", "Very efficient! Common in binary search algorithms."),
    "nlogn": ("O(n log n)", "⚙️ Linearithmic Time", "Common in efficient sorting algorithms like merge sort."),
    "quadratic": ("O(n²)", "🐢 Quadratic Time", "Execution time grows quadratically. Common in nested loops."),
    "cubic": ("O(n³)", "🦕 Cubic Time", "Triple nested loops. Avoid for large inputs."),
    "np": ("O(2ⁿ)", "💀 Exponential Time", "NP-Hard complexity. Only feasible for very small inputs."),
}


def predict(code):
    """Classify the time complexity of *code*.

    Args:
        code: Source text from the input textbox; may be empty or None.

    Returns:
        A ``(notation, title, description)`` tuple of strings for the three
        output boxes. For blank input the first element is a warning message
        and the other two are empty strings.
    """
    # Treat None like blank text so .strip() is never called on None.
    if not code or not code.strip():
        return "⚠️ Please paste some code first!", "", ""

    # A single example needs no padding: pad positions are masked out anyway,
    # so padding to MAX_TOKENS only adds attention work without changing the
    # prediction.
    inputs = tokenizer(
        code,
        truncation=True,
        max_length=MAX_TOKENS,
        return_tensors="pt",
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    pred = torch.argmax(outputs.logits, dim=1).item()
    label = le.inverse_transform([pred])[0]
    # Labels without an entry fall back to the raw label and no explanation.
    notation, title, description = DESCRIPTIONS.get(label, (label, label, ""))
    return notation, title, description


# Custom CSS
css = """
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;700&family=Syne:wght@400;700;800&display=swap');

* { box-sizing: border-box; }

body, .gradio-container {
    background: #0a0a0f !important;
    font-family: 'Syne', sans-serif !important;
}

.gradio-container {
    max-width: 900px !important;
    margin: 0 auto !important;
}

#header {
    text-align: center;
    padding: 40px 20px 20px;
}

#header h1 {
    font-size: 2.8em;
    font-weight: 800;
    background: linear-gradient(135deg, #00ff88, #00cfff);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    margin-bottom: 8px;
    letter-spacing: -1px;
}

#header p {
    color: #888;
    font-size: 1em;
    font-family: 'JetBrains Mono', monospace;
}

.gr-textbox textarea {
    background: #111118 !important;
    border: 1px solid #222 !important;
    color: #e0e0e0 !important;
    font-family: 'JetBrains Mono', monospace !important;
    font-size: 0.85em !important;
    border-radius: 12px !important;
    padding: 16px !important;
}

.gr-button-primary {
    background: linear-gradient(135deg, #00ff88, #00cfff) !important;
    color: #000 !important;
    font-weight: 700 !important;
    font-family: 'Syne', sans-serif !important;
    border: none !important;
    border-radius: 10px !important;
    font-size: 1em !important;
    letter-spacing: 0.5px !important;
}

.gr-button-primary:hover {
    opacity: 0.9 !important;
    transform: translateY(-1px) !important;
}

.result-box {
    background: #111118;
    border: 1px solid #222;
    border-radius: 12px;
    padding: 20px;
    color: #e0e0e0;
}

label {
    color: #666 !important;
    font-family: 'JetBrains Mono', monospace !important;
    font-size: 0.75em !important;
    letter-spacing: 1px !important;
    text-transform: uppercase !important;
}

.gr-textbox { border-radius: 12px !important; }
"""

# Examples: one snippet per complexity class, written as valid, properly
# indented Python so users (and the model) see real code.
examples = [
    [
        "def get_first(arr):\n"
        "    return arr[0]"
    ],
    [
        "def linear_search(arr, target):\n"
        "    for i in range(len(arr)):\n"
        "        if arr[i] == target:\n"
        "            return i\n"
        "    return -1"
    ],
    [
        "def binary_search(arr, target):\n"
        "    left, right = 0, len(arr) - 1\n"
        "    while left <= right:\n"
        "        mid = (left + right) // 2\n"
        "        if arr[mid] == target:\n"
        "            return mid\n"
        "        elif arr[mid] < target:\n"
        "            left = mid + 1\n"
        "        else:\n"
        "            right = mid - 1\n"
        "    return -1"
    ],
    [
        "def bubble_sort(arr):\n"
        "    n = len(arr)\n"
        "    for i in range(n):\n"
        "        for j in range(0, n-i-1):\n"
        "            if arr[j] > arr[j+1]:\n"
        "                arr[j], arr[j+1] = arr[j+1], arr[j]"
    ],
]

with gr.Blocks(css=css, title="Code Complexity Predictor") as demo:
    # NOTE(review): this header HTML is empty, yet the CSS above styles
    # "#header h1" / "#header p" -- confirm the intended header markup.
    gr.HTML(""" """)

    with gr.Row():
        with gr.Column(scale=3):
            code_input = gr.Textbox(
                label="YOUR CODE",
                placeholder="# Paste your Python or Java code here...",
                lines=14,
                max_lines=20,
            )
            predict_btn = gr.Button("⚡ Analyze Complexity", variant="primary")

        with gr.Column(scale=2):
            notation_out = gr.Textbox(label="BIG-O NOTATION", interactive=False)
            title_out = gr.Textbox(label="COMPLEXITY CLASS", interactive=False)
            desc_out = gr.Textbox(label="EXPLANATION", interactive=False, lines=3)

    gr.Examples(
        examples=examples,
        inputs=code_input,
        label="Try these examples",
    )

    # predict returns one string per output box, in this order.
    predict_btn.click(
        fn=predict,
        inputs=code_input,
        outputs=[notation_out, title_out, desc_out],
    )

# Launch only when run as a script, so importing this module (e.g. for tests
# or Gradio reload mode) does not start a server.
if __name__ == "__main__":
    demo.launch()