raghuram00 committed on
Commit
c4c000a
·
verified ·
1 Parent(s): f15ff8a

Add model and app files

Browse files
Files changed (4) hide show
  1. app.py +186 -0
  2. best_model.pt +3 -0
  3. label_encoder.pkl +3 -0
  4. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import joblib
5
+ import os
6
+
7
# Tokenizer for the "microsoft/codebert-base" checkpoint; predict() uses it to
# turn pasted code into model inputs.
tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")

# Label encoder used by predict() to map a class index back to its label name.
# NOTE(review): joblib.load unpickles arbitrary Python objects. Only ship a
# label_encoder.pkl that was produced by a trusted training run.
le = joblib.load("label_encoder.pkl")

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the base architecture with a 7-way classification head, then overwrite
# its weights with the fine-tuned checkpoint stored next to this file.
model = AutoModelForSequenceClassification.from_pretrained(
    "microsoft/codebert-base",
    num_labels=7,
)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code while being loaded.
# map_location keeps CPU-only hosts working with a checkpoint saved on a GPU.
state_dict = torch.load("best_model.pt", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)

# Inference only: switch off training-time behaviour such as dropout.
model.eval()
22
+
23
# One row per complexity class: (label, Big-O notation, display title, explanation).
# The label is the key that predict() looks up after decoding the model output.
_COMPLEXITY_ROWS = (
    ("constant", "O(1)", "⚡ Constant Time", "Executes in the same time regardless of input size. Very fast!"),
    ("linear", "O(n)", "📈 Linear Time", "Execution time grows linearly with input size."),
    ("logn", "O(log n)", "🔍 Logarithmic Time", "Very efficient! Common in binary search algorithms."),
    ("nlogn", "O(n log n)", "⚙️ Linearithmic Time", "Common in efficient sorting algorithms like merge sort."),
    ("quadratic", "O(n²)", "🐢 Quadratic Time", "Execution time grows quadratically. Common in nested loops."),
    ("cubic", "O(n³)", "🦕 Cubic Time", "Triple nested loops. Avoid for large inputs."),
    ("np", "O(2ⁿ)", "💀 Exponential Time", "NP-Hard complexity. Only feasible for very small inputs."),
)

# label -> (notation, title, explanation), in the same order as the rows above.
DESCRIPTIONS = {
    label: (notation, title, explanation)
    for label, notation, title, explanation in _COMPLEXITY_ROWS
}
33
+
34
def predict(code):
    """Classify the time complexity of a code snippet.

    Args:
        code: Source code text from the UI textbox. ``None``, empty and
            whitespace-only values are all treated as "no input".

    Returns:
        A ``(notation, title, description)`` tuple of strings. When there is
        no input, the first element is a warning message and the other two
        are empty. When the decoded label has no entry in ``DESCRIPTIONS``,
        the label itself is returned as both notation and title, with an
        empty description.
    """
    # Guard against None as well as blank text so a missing value cannot
    # raise AttributeError on .strip().
    if code is None or not code.strip():
        return "⚠️ Please paste some code first!", "", ""

    # Snippets longer than 512 tokens are truncated; shorter ones are padded
    # up to 512.
    # NOTE(review): fixed-length padding presumably mirrors the training
    # preprocessing; confirm before switching to dynamic padding for speed.
    inputs = tokenizer(
        code,
        truncation=True,
        max_length=512,
        padding='max_length',
        return_tensors='pt',
    )
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    # No gradients are needed for inference.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        pred = torch.argmax(outputs.logits, dim=1).item()

    # Decode the class index to its label name; str() normalises the numpy
    # string type returned by the encoder to a plain Python str.
    label = str(le.inverse_transform([pred])[0])
    notation, title, description = DESCRIPTIONS.get(label, (label, label, ""))
    return notation, title, description
57
+
58
+
59
# Custom CSS passed to gr.Blocks(css=...): dark theme, gradient header and
# monospace code input. The header gradient text sets the standard
# `background-clip` next to the `-webkit-` prefixed form so it does not rely
# on the vendor-prefixed property alone.
css = """
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;700&family=Syne:wght@400;700;800&display=swap');

* { box-sizing: border-box; }

body, .gradio-container {
    background: #0a0a0f !important;
    font-family: 'Syne', sans-serif !important;
}

.gradio-container {
    max-width: 900px !important;
    margin: 0 auto !important;
}

#header {
    text-align: center;
    padding: 40px 20px 20px;
}

#header h1 {
    font-size: 2.8em;
    font-weight: 800;
    background: linear-gradient(135deg, #00ff88, #00cfff);
    -webkit-background-clip: text;
    background-clip: text;
    -webkit-text-fill-color: transparent;
    margin-bottom: 8px;
    letter-spacing: -1px;
}

#header p {
    color: #888;
    font-size: 1em;
    font-family: 'JetBrains Mono', monospace;
}

.gr-textbox textarea {
    background: #111118 !important;
    border: 1px solid #222 !important;
    color: #e0e0e0 !important;
    font-family: 'JetBrains Mono', monospace !important;
    font-size: 0.85em !important;
    border-radius: 12px !important;
    padding: 16px !important;
}

.gr-button-primary {
    background: linear-gradient(135deg, #00ff88, #00cfff) !important;
    color: #000 !important;
    font-weight: 700 !important;
    font-family: 'Syne', sans-serif !important;
    border: none !important;
    border-radius: 10px !important;
    font-size: 1em !important;
    letter-spacing: 0.5px !important;
}

.gr-button-primary:hover {
    opacity: 0.9 !important;
    transform: translateY(-1px) !important;
}

.result-box {
    background: #111118;
    border: 1px solid #222;
    border-radius: 12px;
    padding: 20px;
    color: #e0e0e0;
}

label {
    color: #666 !important;
    font-family: 'JetBrains Mono', monospace !important;
    font-size: 0.75em !important;
    letter-spacing: 1px !important;
    text-transform: uppercase !important;
}

.gr-textbox {
    border-radius: 12px !important;
}
"""
142
+
143
# Example snippets offered below the input box. Each row holds the value for
# the single `code_input` component. Nested blocks are indented with four
# spaces so every snippet is valid Python.
examples = [
    [
        "def get_first(arr):\n"
        "    return arr[0]"
    ],
    [
        "def linear_search(arr, target):\n"
        "    for i in range(len(arr)):\n"
        "        if arr[i] == target:\n"
        "            return i\n"
        "    return -1"
    ],
    [
        "def binary_search(arr, target):\n"
        "    left, right = 0, len(arr) - 1\n"
        "    while left <= right:\n"
        "        mid = (left + right) // 2\n"
        "        if arr[mid] == target:\n"
        "            return mid\n"
        "        elif arr[mid] < target:\n"
        "            left = mid + 1\n"
        "        else:\n"
        "            right = mid - 1\n"
        "    return -1"
    ],
    [
        "def bubble_sort(arr):\n"
        "    n = len(arr)\n"
        "    for i in range(n):\n"
        "        for j in range(0, n-i-1):\n"
        "            if arr[j] > arr[j+1]:\n"
        "                arr[j], arr[j+1] = arr[j+1], arr[j]"
    ],
]
150
+
151
# Assemble the UI: header, code input with an analyze button on the left,
# three read-only result fields on the right, and clickable examples below.
with gr.Blocks(css=css, title="Code Complexity Predictor") as demo:
    gr.HTML("""
    <div id="header">
        <h1>⚙️ Code Complexity Predictor</h1>
        <p>// powered by CodeBERT — paste your code, get instant Big-O analysis</p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=3):
            code_input = gr.Textbox(
                label="YOUR CODE",
                placeholder="# Paste your Python or Java code here...",
                lines=14,
                max_lines=20,
            )
            predict_btn = gr.Button("⚡ Analyze Complexity", variant="primary")

        with gr.Column(scale=2):
            notation_out = gr.Textbox(label="BIG-O NOTATION", interactive=False)
            title_out = gr.Textbox(label="COMPLEXITY CLASS", interactive=False)
            desc_out = gr.Textbox(label="EXPLANATION", interactive=False, lines=3)

    # Clicking an example only fills the input box; the user still presses
    # the button to run the prediction.
    gr.Examples(
        examples=examples,
        inputs=code_input,
        label="Try these examples",
    )

    # predict() returns (notation, title, description), matching the three
    # output components in order.
    predict_btn.click(
        fn=predict,
        inputs=code_input,
        outputs=[notation_out, title_out, desc_out],
    )

# Start the server only when this file is executed as a script, so importing
# the module (for tests or tooling) does not launch a web server.
if __name__ == "__main__":
    demo.launch()
best_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68768fd3e53895c8b107b2b012af92acae29bfd281bf88b1d2427572a02e7b59
3
+ size 498687962
label_encoder.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7594bbda10b0e8aec7a8c30ad8eb5324954b2bbc2b1c60f982817806cef66cad
3
+ size 533
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ transformers==4.40.0
2
+ torch==2.2.0
3
+ gradio==4.44.0
4
+ joblib==1.3.2
5
+ scikit-learn==1.4.0