raghuram00 commited on
Commit
24be017
·
0 Parent(s):

Initial commit

Browse files
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ .env
4
+ venv/
5
+ .DS_Store
6
+
7
+ # Large Model Files
8
+ best_model.pt
9
+ label_encoder.pkl
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Slim Python base keeps the final image small; 3.9 matches the training env.
FROM python:3.9-slim

# Set work directory
WORKDIR /app

# Install dependencies first so Docker layer caching skips the (slow) pip
# install when only application code changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the entire project
COPY . .

# Download model weights at build time (bypasses GitHub LFS issues)
RUN python download_model.py

# Expose port for Render
EXPOSE 8000

# Run FastAPI via Uvicorn
CMD ["uvicorn", "backend.main:app", "--host", "0.0.0.0", "--port", "8000"]
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Code Complexity Predictor
3
+ emoji: 📉
4
+ colorFrom: blue
5
+ colorTo: yellow
6
+ sdk: docker
7
+ app_port: 8000
9
+ pinned: false
10
+ license: apache-2.0
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
Untitled2.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
backend/main.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import torch
import joblib
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Complexity descriptions: label -> (Big-O notation, display title, explanation).
# Keys must match the class names produced by the LabelEncoder saved at
# training time (see improved_training.py).
DESCRIPTIONS = {
    "constant": ("O(1)", "⚡ Constant Time", "Executes in the same time regardless of input size. Very fast!"),
    "linear": ("O(n)", "📈 Linear Time", "Execution time grows linearly with input size."),
    "logn": ("O(log n)", "🔍 Logarithmic Time", "Very efficient! Common in binary search algorithms."),
    "nlogn": ("O(n log n)", "⚙️ Linearithmic Time", "Common in efficient sorting algorithms like merge sort."),
    "quadratic": ("O(n²)", "🐢 Quadratic Time", "Execution time grows quadratically. Common in nested loops."),
    "cubic": ("O(n³)", "🦕 Cubic Time", "Triple nested loops. Avoid for large inputs."),
    "np": ("O(2ⁿ)", "💀 Exponential Time", "NP-Hard complexity. Only feasible for very small inputs."),
}

app = FastAPI(title="Code Complexity Predictor API")

class PredictRequest(BaseModel):
    # Raw source code snippet to classify.
    code: str

# Global state, populated once by the startup hook below (None until then).
model = None      # fine-tuned sequence-classification model
tokenizer = None  # CodeBERT tokenizer
le = None         # sklearn LabelEncoder: class index -> complexity label string
device = None     # torch device (cuda if available, else cpu)
30
+
31
@app.on_event("startup")
def load_resources():
    """Load tokenizer, label encoder and fine-tuned model once at startup.

    Populates the module-level globals used by the /api/predict endpoint.
    Missing artifact files are tolerated with a warning so the app can still
    boot (the endpoint will then fail at inference time).
    NOTE(review): @app.on_event is deprecated in recent FastAPI versions in
    favour of lifespan handlers — still functional, but worth migrating.
    """
    global model, tokenizer, le, device
    print("Loading resources...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load tokenizer (base CodeBERT vocabulary; no fine-tuned tokenizer needed)
    tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")

    # Load label encoder (maps predicted class indices back to label strings)
    if os.path.exists("label_encoder.pkl"):
        le = joblib.load("label_encoder.pkl")
    else:
        print("WARNING: label_encoder.pkl not found!")

    # Load model: base architecture with a 7-way classification head, then
    # overwrite the weights with the fine-tuned checkpoint if present.
    if os.path.exists("best_model.pt"):
        model = AutoModelForSequenceClassification.from_pretrained("microsoft/codebert-base", num_labels=7)
        model.load_state_dict(torch.load("best_model.pt", map_location=device))
    else:
        model = AutoModelForSequenceClassification.from_pretrained("microsoft/codebert-base", num_labels=7)
        print("WARNING: best_model.pt not found!")

    model.to(device)
    model.eval()
    print("Resources loaded successfully!")
56
+
57
@app.post("/api/predict")
def predict_complexity(request: PredictRequest):
    """Classify the time complexity of a code snippet.

    Returns a dict with the Big-O `notation`, a human-readable `title` and a
    short `description`.

    Raises:
        HTTPException 400: the submitted code is empty/whitespace.
        HTTPException 503: model artifacts were not loaded at startup.
        HTTPException 500: unexpected inference failure.
    """
    code = request.code
    if not code.strip():
        raise HTTPException(status_code=400, detail="Code cannot be empty")

    # Fail fast with a clear error instead of an opaque AttributeError->500
    # when best_model.pt / label_encoder.pkl were missing at startup.
    if model is None or tokenizer is None or le is None:
        raise HTTPException(status_code=503, detail="Model resources are not loaded")

    try:
        inputs = tokenizer(code, truncation=True, max_length=512, padding='max_length', return_tensors='pt')
        input_ids = inputs['input_ids'].to(device)
        attention_mask = inputs['attention_mask'].to(device)

        # Inference only — no gradients needed.
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            pred = torch.argmax(outputs.logits, dim=1).item()

        # Map class index -> label string -> display tuple (fall back to the
        # bare label if it has no entry in DESCRIPTIONS).
        label = le.inverse_transform([pred])[0]
        notation, title, description = DESCRIPTIONS.get(label, (label, label, ""))

        return {
            "notation": notation,
            "title": title,
            "description": description
        }
    except HTTPException:
        # Don't convert deliberate HTTP errors into generic 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
82
+
83
# Mount frontend: serve the static files in ./frontend at the root path.
# Registered after the API routes so /api/* keeps precedence over "/".
app.mount("/", StaticFiles(directory="frontend", html=True), name="frontend")
download_model.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import urllib.request
2
+ import os
3
+
4
def download_file(url, target_path):
    """Download *url* to *target_path*, streaming the body to disk.

    Streams in chunks via shutil.copyfileobj instead of a single
    response.read(): the model checkpoint is large, so buffering the whole
    response in memory could exhaust RAM on small build containers.
    """
    import shutil  # local import keeps the module's top-level imports unchanged

    print(f"Downloading {target_path} from {url}...")
    # Some hosts reject urllib's default User-Agent, so send a browser-like one.
    req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
    with urllib.request.urlopen(req) as response, open(target_path, 'wb') as out_file:
        shutil.copyfileobj(response, out_file)
    print(f"✅ Successfully downloaded {target_path} (Size: {os.path.getsize(target_path)/(1024*1024):.2f} MB)")
10
+
11
# Artifact URLs on the Hugging Face Space (resolved through LFS).
url_model = "https://huggingface.co/spaces/raghuram00/code-complexity-predictor/resolve/main/best_model.pt"
url_le = "https://huggingface.co/spaces/raghuram00/code-complexity-predictor/resolve/main/label_encoder.pkl"

# Fetch each artifact only when it is not already present on disk,
# so repeated builds / restarts skip the download.
for artifact_url, artifact_path in ((url_model, "best_model.pt"), (url_le, "label_encoder.pkl")):
    if not os.path.exists(artifact_path):
        download_file(artifact_url, artifact_path)
frontend/index.html ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Code Complexity Predictor</title>
    <!-- Modern Fonts -->
    <link rel="preconnect" href="https://fonts.googleapis.com">
    <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
    <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&family=JetBrains+Mono:wght@400;700&family=Outfit:wght@600;800&display=swap" rel="stylesheet">

    <!-- Prism.js for code syntax highlighting -->
    <link href="https://cdnjs.cloudflare.com/ajax/libs/prism/1.29.0/themes/prism-tomorrow.min.css" rel="stylesheet" />

    <!-- Custom Styles -->
    <link rel="stylesheet" href="styles.css">
</head>
<body>
    <!-- Decorative blurred orbs animated behind the app (see styles.css) -->
    <div class="background-effects">
        <div class="glow-orb orb-1"></div>
        <div class="glow-orb orb-2"></div>
    </div>

    <main class="container">
        <header>
            <div class="logo-wrapper">
                <svg width="40" height="40" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
                    <path d="M10 20L14 4M18 8L22 12L18 16M6 16L2 12L6 8" stroke="url(#paint0_linear)" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
                    <defs>
                        <linearGradient id="paint0_linear" x1="2" y1="12" x2="22" y2="12" gradientUnits="userSpaceOnUse">
                            <stop stop-color="#00F0FF" />
                            <stop offset="1" stop-color="#8000FF" />
                        </linearGradient>
                    </defs>
                </svg>
                <h1>Complexity <span>Predictor</span></h1>
            </div>
            <p class="subtitle">// AI-powered algorithmic Big-O analysis using CodeBERT</p>
        </header>

        <div class="app-grid">
            <!-- Left Side: Code Input -->
            <section class="editor-section">
                <div class="card glass-panel">
                    <div class="card-header">
                        <span class="dot red"></span>
                        <span class="dot yellow"></span>
                        <span class="dot green"></span>
                        <span class="filename">algorithm.py</span>
                    </div>
                    <div class="editor-wrapper">
                        <!-- We use a textarea over a pre code block to allow editing,
                             but styling them to overlap for syntax highlighting -->
                        <textarea id="codeInput" placeholder="# Paste your Python or Java code here...
def example(arr):
for item in arr:
print(item)"></textarea>
                    </div>
                </div>
                <button id="analyzeBtn" class="primary-btn">
                    <span class="btn-text">⚡ Analyze Complexity</span>
                    <span class="loader hidden"></span>
                </button>

                <div class="examples-section">
                    <h3>Try these examples:</h3>
                    <!-- data-example carries escaped \n sequences; script.js
                         expands them into real newlines before analysis -->
                    <div class="example-chips">
                        <button class="chip" data-example='def get_first(arr):\n return arr[0]'>O(1) Constant</button>
                        <button class="chip" data-example='def linear_search(arr, target):\n for i in range(len(arr)):\n if arr[i] == target:\n return i\n return -1'>O(n) Linear</button>
                        <button class="chip" data-example='def bubble_sort(arr):\n n = len(arr)\n for i in range(n):\n for j in range(0, n-i-1):\n if arr[j] > arr[j+1]:\n arr[j], arr[j+1] = arr[j+1], arr[j]'>O(n²) Quadratic</button>
                    </div>
                </div>
            </section>

            <!-- Right Side: Results (populated by script.js from /api/predict) -->
            <section class="results-section">
                <div class="result-card glass-panel">
                    <h2 class="result-label">Big-O Notation</h2>
                    <div class="result-value glow-text" id="resNotation">O(?)</div>
                </div>

                <div class="result-card glass-panel">
                    <h2 class="result-label">Complexity Class</h2>
                    <div class="result-value" id="resTitle">-</div>
                </div>

                <div class="result-card glass-panel desc-card">
                    <h2 class="result-label">Explanation</h2>
                    <p class="result-desc" id="resDesc">Awaiting code input to predict complexity.</p>
                </div>
            </section>
        </div>
    </main>

    <script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.29.0/prism.min.js"></script>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.29.0/components/prism-python.min.js"></script>
    <script src="script.js"></script>
</body>
</html>
frontend/script.js ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
document.addEventListener('DOMContentLoaded', () => {
    const codeInput = document.getElementById('codeInput');
    const analyzeBtn = document.getElementById('analyzeBtn');
    const btnText = analyzeBtn.querySelector('.btn-text');
    const loader = analyzeBtn.querySelector('.loader');

    // Result elements
    const resNotation = document.getElementById('resNotation');
    const resTitle = document.getElementById('resTitle');
    const resDesc = document.getElementById('resDesc');

    // Example chips: clicking one fills the editor and analyzes immediately
    document.querySelectorAll('.chip').forEach(chip => {
        chip.addEventListener('click', () => {
            // Replace \n from dataset with actual newlines
            codeInput.value = chip.dataset.example.replace(/\\n/g, '\n');
            triggerAnalysis();
        });
    });

    analyzeBtn.addEventListener('click', triggerAnalysis);

    // Also support Cmd/Ctrl + Enter to trigger
    codeInput.addEventListener('keydown', (e) => {
        if ((e.ctrlKey || e.metaKey) && e.key === 'Enter') {
            triggerAnalysis();
        }
    });

    /**
     * POST the editor contents to /api/predict and render the response.
     * Manages the button's loading state and fades the result cards while
     * the request is in flight.
     */
    async function triggerAnalysis() {
        const code = codeInput.value.trim();

        if (!code) {
            resNotation.innerHTML = "O(?)";
            resTitle.innerHTML = "Error";
            resDesc.innerHTML = "⚠️ Please paste some code before analyzing!";
            return;
        }

        // UI Loading state
        analyzeBtn.disabled = true;
        btnText.innerHTML = "Analyzing structure...";
        loader.classList.remove('hidden');
        resNotation.style.opacity = '0.5';
        resTitle.style.opacity = '0.5';
        resDesc.style.opacity = '0.5';

        try {
            const response = await fetch('/api/predict', {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json'
                },
                body: JSON.stringify({ code: code })
            });

            if (!response.ok) {
                throw new Error(`Server error: ${response.status}`);
            }

            const data = await response.json();

            // Render Results
            resNotation.innerHTML = data.notation;
            resTitle.innerHTML = data.title;
            resDesc.innerHTML = data.description;

        } catch (error) {
            console.error('Analysis failed:', error);
            resNotation.innerHTML = "O(?)";
            resTitle.innerHTML = "Analysis Failed";
            resDesc.innerHTML = "An error occurred while connecting to the AI model. Ensure the backend is running.";
        } finally {
            // Restore UI state
            analyzeBtn.disabled = false;
            btnText.innerHTML = "⚡ Analyze Complexity";
            loader.classList.add('hidden');

            // Fade results back in
            resNotation.style.opacity = '1';
            resTitle.style.opacity = '1';
            resDesc.style.opacity = '1';

            // Add a little pop animation to the results
            document.querySelectorAll('.result-card').forEach(card => {
                card.style.transform = 'scale(1.02)';
                setTimeout(() => card.style.transform = 'scale(1)', 200);
            });
        }
    }
});
frontend/styles.css ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/* Design tokens for the dark "glassmorphism" theme used across the app. */
:root {
    --bg-color: #050508;
    --panel-bg: rgba(22, 22, 30, 0.6);
    --border-color: rgba(255, 255, 255, 0.08);
    --text-main: #f0f0f5;
    --text-dim: #8b8b9f;
    --accent-primary: #00F0FF;
    --accent-secondary: #8000FF;
    --glass-blur: 16px;
}

* {
    margin: 0;
    padding: 0;
    box-sizing: border-box;
}

body {
    background-color: var(--bg-color);
    color: var(--text-main);
    font-family: 'Inter', sans-serif;
    min-height: 100vh;
    display: flex;
    justify-content: center;
    position: relative;
    overflow-x: hidden;
}

/* Animated Background Orbs */
.background-effects {
    position: fixed;
    top: 0;
    left: 0;
    width: 100vw;
    height: 100vh;
    z-index: -1;
    overflow: hidden;
    pointer-events: none;
}

.glow-orb {
    position: absolute;
    border-radius: 50%;
    filter: blur(120px);
    opacity: 0.4;
    animation: float 20s infinite alternate ease-in-out;
}

.orb-1 {
    width: 600px;
    height: 600px;
    background: var(--accent-secondary);
    top: -200px;
    right: -100px;
}

.orb-2 {
    width: 500px;
    height: 500px;
    background: var(--accent-primary);
    bottom: -200px;
    left: -100px;
    animation-delay: -10s; /* offset so the two orbs drift out of phase */
}

@keyframes float {
    0% { transform: translate(0, 0) scale(1); }
    100% { transform: translate(-100px, 100px) scale(1.2); }
}

.container {
    width: 100%;
    max-width: 1200px;
    padding: 3rem 2rem;
    display: flex;
    flex-direction: column;
    gap: 3rem;
}

/* Header */
header {
    text-align: center;
    display: flex;
    flex-direction: column;
    align-items: center;
    gap: 1rem;
    animation: fadeDown 0.8s ease-out;
}

.logo-wrapper {
    display: flex;
    align-items: center;
    gap: 1rem;
}

h1 {
    font-family: 'Outfit', sans-serif;
    font-size: 3.5rem;
    font-weight: 800;
    letter-spacing: -1px;
}

/* Gradient text effect on the highlighted word of the title */
h1 span {
    background: linear-gradient(135deg, var(--accent-primary), var(--accent-secondary));
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
}

.subtitle {
    font-family: 'JetBrains Mono', monospace;
    color: var(--text-dim);
    font-size: 0.9rem;
}

/* Main Grid */
.app-grid {
    display: grid;
    grid-template-columns: 1.5fr 1fr;
    gap: 2rem;
    align-items: start;
    animation: fadeUp 1s ease-out;
}

/* Glassmorphism Panels */
.glass-panel {
    background: var(--panel-bg);
    backdrop-filter: blur(var(--glass-blur));
    -webkit-backdrop-filter: blur(var(--glass-blur));
    border: 1px solid var(--border-color);
    border-radius: 20px;
    box-shadow: 0 8px 32px 0 rgba(0, 0, 0, 0.3);
}

/* Editor Section */
.card-header {
    padding: 1rem;
    border-bottom: 1px solid var(--border-color);
    display: flex;
    align-items: center;
    gap: 0.5rem;
    background: rgba(0,0,0,0.2);
    border-radius: 20px 20px 0 0;
}

/* macOS-style traffic-light window dots */
.dot {
    width: 12px;
    height: 12px;
    border-radius: 50%;
}
.red { background: #ff5f56; }
.yellow { background: #ffbd2e; }
.green { background: #27c93f; }

.filename {
    margin-left: 1rem;
    font-family: 'JetBrains Mono', monospace;
    font-size: 0.8rem;
    color: var(--text-dim);
}

.editor-wrapper {
    position: relative;
    padding: 1rem;
}

textarea {
    width: 100%;
    min-height: 350px;
    background: transparent;
    border: none;
    color: var(--text-main);
    font-family: 'JetBrains Mono', monospace;
    font-size: 0.9rem;
    line-height: 1.6;
    resize: vertical;
    outline: none;
}

textarea::placeholder {
    color: rgba(255, 255, 255, 0.2);
}

/* Button */
.primary-btn {
    width: 100%;
    margin-top: 1.5rem;
    padding: 1.2rem;
    border: none;
    border-radius: 12px;
    background: linear-gradient(135deg, var(--accent-primary), var(--accent-secondary));
    color: #fff;
    font-family: 'Outfit', sans-serif;
    font-size: 1.2rem;
    font-weight: 600;
    cursor: pointer;
    transition: all 0.3s ease;
    display: flex;
    justify-content: center;
    align-items: center;
    gap: 1rem;
    position: relative;
    overflow: hidden;
}

/* Sheen overlay that sweeps across the button on hover */
.primary-btn::before {
    content: '';
    position: absolute;
    top: 0; left: -100%;
    width: 100%; height: 100%;
    background: linear-gradient(90deg, transparent, rgba(255,255,255,0.2), transparent);
    transition: 0.5s;
}

.primary-btn:hover::before {
    left: 100%;
}

.primary-btn:hover {
    transform: translateY(-2px);
    box-shadow: 0 10px 20px rgba(128, 0, 255, 0.3);
}

.primary-btn:active {
    transform: translateY(0);
}

/* Loader */
.loader {
    width: 20px;
    height: 20px;
    border: 3px solid rgba(255,255,255,0.3);
    border-radius: 50%;
    border-top-color: #fff;
    animation: spin 1s ease-in-out infinite;
}
.hidden { display: none; }

@keyframes spin {
    to { transform: rotate(360deg); }
}

/* Examples */
.examples-section {
    margin-top: 2rem;
}
.examples-section h3 {
    font-size: 0.9rem;
    color: var(--text-dim);
    margin-bottom: 1rem;
    font-weight: 500;
}
.example-chips {
    display: flex;
    gap: 0.8rem;
    flex-wrap: wrap;
}
.chip {
    background: rgba(255,255,255,0.05);
    border: 1px solid rgba(255,255,255,0.1);
    color: var(--text-main);
    padding: 0.5rem 1rem;
    border-radius: 20px;
    font-family: 'JetBrains Mono', monospace;
    font-size: 0.8rem;
    cursor: pointer;
    transition: all 0.2s;
}
.chip:hover {
    background: rgba(255,255,255,0.1);
    border-color: var(--accent-primary);
}

/* Results Section */
.results-section {
    display: flex;
    flex-direction: column;
    gap: 1.5rem;
}

.result-card {
    padding: 2rem;
    transition: transform 0.3s; /* enables the "pop" animation from script.js */
}

.result-card:hover {
    transform: translateY(-2px);
}

.result-label {
    font-size: 0.8rem;
    text-transform: uppercase;
    letter-spacing: 2px;
    color: var(--text-dim);
    margin-bottom: 1rem;
}

.result-value {
    font-family: 'Outfit', sans-serif;
    font-size: 2rem;
    font-weight: 800;
}

.glow-text {
    background: linear-gradient(135deg, var(--accent-primary), #fff);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-size: 3rem;
}

.result-desc {
    color: #a0a0b5;
    line-height: 1.6;
    font-size: 1.1rem;
}

/* Animations */
@keyframes fadeDown {
    from { opacity: 0; transform: translateY(-20px); }
    to { opacity: 1; transform: translateY(0); }
}

@keyframes fadeUp {
    from { opacity: 0; transform: translateY(20px); }
    to { opacity: 1; transform: translateY(0); }
}

/* Responsive: stack editor above results on narrow screens */
@media (max-width: 900px) {
    .app-grid {
        grid-template-columns: 1fr;
    }
    h1 { font-size: 2.5rem; }
}
improved_training.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# ==============================================================================
# 🚀 IMPROVED CODE COMPLEXITY PREDICTOR TRAINING SCRIPT 🚀
# ==============================================================================
# Run this entire script in Google Colab (either pasted into a cell or via script)

# 1. Install dependencies
# !pip install -q transformers datasets torch scikit-learn

import pandas as pd
import torch
import torch.nn as nn
from datasets import load_dataset
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import get_linear_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from tqdm import tqdm
import os
import shutil

# ------------------------------------------------------------------------------
# ⚙️ CONFIGURATION & HYPERPARAMETERS
# ------------------------------------------------------------------------------
MODEL_NAME = "microsoft/graphcodebert-base"  # 🔥 Upgraded to GraphCodeBERT
MAX_LEN = 512         # Max token length per code snippet
BATCH_SIZE = 16       # Training batch size
EPOCHS = 15           # 🔥 Increased from 3 to 15 (early stopping limits overfit)
LEARNING_RATE = 3e-5  # Optimized initial learning rate
WEIGHT_DECAY = 0.05   # 🔥 Increased regularization
PATIENCE = 3          # 🔥 Early-stopping patience (epochs without improvement)
SAVE_PATH = "best_model.pt"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🖥️ Using device: {device}")
37
+
38
# ------------------------------------------------------------------------------
# 📊 DATA PREPARATION
# ------------------------------------------------------------------------------
print("\n[1/5] Loading Dataset...")
dataset = load_dataset("codeparrot/codecomplex")
df = pd.DataFrame(dataset['train'])

# Encode complexity strings (e.g. "linear") into integer class labels.
le = LabelEncoder()
df['label'] = le.fit_transform(df['complexity'])

# Save Label Encoder for Inference (the serving API needs it to decode predictions)
import joblib
joblib.dump(le, "label_encoder.pkl")

# Train/Test Split (stratified so every complexity class keeps its proportion)
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['label'])

# Calculate Class Weights to handle imbalance right from the start:
# weight_i = total / count_i, so rarer classes get a larger loss weight.
class_counts = train_df['label'].value_counts().sort_index().values
total_samples = sum(class_counts)
class_weights = torch.tensor([total_samples / c for c in class_counts], dtype=torch.float).to(device)

print(f"✅ Loaded {len(train_df)} training and {len(test_df)} testing samples.")

# ------------------------------------------------------------------------------
# 🧠 TOKENIZATION & DATASETS
# ------------------------------------------------------------------------------
print(f"\n[2/5] Initializing Tokenizer ({MODEL_NAME})...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
68
+
69
class CodeDataset(Dataset):
    """Torch dataset yielding tokenized code snippets with integer labels.

    Expects a DataFrame with 'src' (code text) and 'label' (encoded class)
    columns; each item is a dict ready for the sequence-classification model.
    """

    def __init__(self, dataframe, tokenizer, max_length=MAX_LEN):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        tokens = self.tokenizer(
            str(row['src']),
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )

        # squeeze() drops the batch dimension added by return_tensors='pt';
        # the DataLoader re-batches the items itself.
        return {
            'input_ids': tokens['input_ids'].squeeze(),
            'attention_mask': tokens['attention_mask'].squeeze(),
            'label': torch.tensor(int(row['label']), dtype=torch.long)
        }
95
+
96
# Build datasets/loaders; reset_index keeps .iloc positions contiguous after the split.
train_dataset = CodeDataset(train_df.reset_index(drop=True), tokenizer)
test_dataset = CodeDataset(test_df.reset_index(drop=True), tokenizer)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# ------------------------------------------------------------------------------
# 🏗️ MODEL INITIALIZATION
# ------------------------------------------------------------------------------
print(f"\n[3/5] Loading Model ({MODEL_NAME})...")
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=7)
model = model.to(device)

# Optimizer with Weight Decay
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Scheduler: linear decay after a warmup ramp over the first 10% of steps.
total_steps = len(train_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(total_steps * 0.1),  # 10% warmup
    num_training_steps=total_steps
)

# Loss function with balanced classes (weights computed in the data-prep step)
criterion = nn.CrossEntropyLoss(weight=class_weights)
122
+
123
+ # ------------------------------------------------------------------------------
124
+ # πŸƒ TRAINING & EVALUATION FUNCTIONS
125
+ # ------------------------------------------------------------------------------
126
def train_epoch(model, loader, optimizer, scheduler, criterion, device):
    """Run one optimization pass over `loader`.

    Returns (mean batch loss, training accuracy) for the epoch.
    """
    model.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0

    for batch in tqdm(loader, desc="Training", leave=False):
        ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        targets = batch['label'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids=ids, attention_mask=mask)
        loss = criterion(outputs.logits, targets)

        loss.backward()
        # Clip gradients to stabilise transformer fine-tuning.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        running_loss += loss.item()
        n_correct += (torch.argmax(outputs.logits, dim=1) == targets).sum().item()
        n_seen += targets.size(0)

    return running_loss / len(loader), n_correct / n_seen
150
+
151
def evaluate(model, loader, device):
    """Return classification accuracy of `model` over `loader` (no gradients)."""
    model.eval()
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating", leave=False):
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            targets = batch['label'].to(device)

            logits = model(input_ids=ids, attention_mask=mask).logits
            n_correct += (torch.argmax(logits, dim=1) == targets).sum().item()
            n_seen += targets.size(0)

    return n_correct / n_seen
167
+
168
+ # ------------------------------------------------------------------------------
169
+ # πŸ₯‡ MAIN TRAINING LOOP WITH EARLY STOPPING
170
+ # ------------------------------------------------------------------------------
171
+ print("\n[4/5] Starting Training Loop...")
172
+ best_accuracy = 0
173
+ epochs_no_improve = 0
174
+
175
+ for epoch in range(EPOCHS):
176
+ print(f"\nπŸ”„ Epoch {epoch+1}/{EPOCHS}")
177
+
178
+ train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, criterion, device)
179
+ test_acc = evaluate(model, test_loader, device)
180
+
181
+ print(f"πŸ“ˆ Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}% | Test Acc: {test_acc*100:.2f}%")
182
+
183
+ # Early Stopping Logic
184
+ if test_acc > best_accuracy:
185
+ best_accuracy = test_acc
186
+ epochs_no_improve = 0
187
+ torch.save(model.state_dict(), SAVE_PATH)
188
+ print(f"⭐ NEW BEST MODEL SAVED! Accuracy: {best_accuracy*100:.2f}%")
189
+ else:
190
+ epochs_no_improve += 1
191
+ print(f"⚠️ No improvement for {epochs_no_improve} epochs.")
192
+
193
+ if epochs_no_improve >= PATIENCE:
194
+ print(f"\n⏹️ EARLY STOPPING TRIGGERED! Test accuracy hasn't improved in {PATIENCE} epochs.")
195
+ break
196
+
197
+ # ------------------------------------------------------------------------------
198
+ # πŸ’Ύ EXPORTING TO DRIVE
199
+ # ------------------------------------------------------------------------------
200
+ print("\n[5/5] Finalizing...")
201
+ try:
202
+ from google.colab import drive
203
+ drive.mount('/content/drive', force_remount=True)
204
+ shutil.copy(SAVE_PATH, f"/content/drive/MyDrive/{SAVE_PATH}")
205
+ shutil.copy("label_encoder.pkl", "/content/drive/MyDrive/label_encoder.pkl")
206
+ print("βœ… Files successfully backed up to Google Drive!")
207
+ except ImportError:
208
+ print("Not running in Colab - skipping Drive export.")
render.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
services:
  - type: web
    name: code-complexity-predictor
    env: docker
    # FIX: `instanceCT` is not a key in the Render blueprint schema and would
    # be rejected/ignored; the field controlling instance count is
    # `numInstances`.
    numInstances: 1
    plan: free
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ transformers
4
+ torch
5
+ scikit-learn
6
+ pandas
7
+ joblib
8
+ python-multipart
train_extracted.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Notebook-extracted script: the `!pip` line below is IPython shell magic and
# only runs inside a notebook/Colab environment, not under plain Python.
!pip install transformers datasets torch scikit-learn

# --- CELL ---

from datasets import load_dataset

# CodeComplex: code snippets labeled with their asymptotic time complexity.
dataset = load_dataset("codeparrot/codecomplex")
print(dataset)
print(dataset['train'][0])

# --- CELL ---

import pandas as pd

df = pd.DataFrame(dataset['train'])

# Check complexity labels (class distribution of the target)
print("Complexity classes:")
print(df['complexity'].value_counts())

print("\nLanguages:")
print(df['from'].value_counts())

print("\nTotal samples:", len(df))

# --- CELL ---

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

# Encode labels: map each complexity-class string to an integer id.
le = LabelEncoder()
df['label'] = le.fit_transform(df['complexity'])

print("Label mapping:")
for i, cls in enumerate(le.classes_):
    print(f" {cls} → {i}")

# Split data — stratified so train/test keep the same class proportions.
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['label'])

print(f"\nTrain size: {len(train_df)}")
print(f"Test size: {len(test_df)}")

# --- CELL ---

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")

print("✅ CodeBERT tokenizer loaded!")

# Test it on a short prefix of the first sample to confirm the tokenizer works.
sample = df['src'][0][:200]
tokens = tokenizer(sample, truncation=True, max_length=512, return_tensors="pt")
print("Sample token shape:", tokens['input_ids'].shape)
57
+
58
+ # --- CELL ---
59
+
60
import torch
from torch.utils.data import Dataset

class CodeDataset(Dataset):
    """Expose a DataFrame of (src, label) rows as a PyTorch Dataset.

    Each item is tokenized lazily on access and returned as fixed-length
    tensors suitable for a sequence-classification model.
    """

    def __init__(self, dataframe, tokenizer, max_length=512):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        encoded = self.tokenizer(
            str(row['src']),
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )

        # squeeze() drops the batch dimension the tokenizer adds for a single
        # example; the DataLoader re-batches downstream.
        return {
            'input_ids': encoded['input_ids'].squeeze(),
            'attention_mask': encoded['attention_mask'].squeeze(),
            'label': torch.tensor(int(row['label']), dtype=torch.long)
        }
89
+
90
+ # Create datasets
91
+ train_dataset = CodeDataset(train_df.reset_index(drop=True), tokenizer)
92
+ test_dataset = CodeDataset(test_df.reset_index(drop=True), tokenizer)
93
+
94
+ print(f"βœ… Train dataset: {len(train_dataset)} samples")
95
+ print(f"βœ… Test dataset: {len(test_dataset)} samples")
96
+
97
+ # --- CELL ---
98
+
99
+ from transformers import AutoModelForSequenceClassification
100
+ import torch
101
+
102
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
103
+ print(f"Using device: {device}")
104
+
105
+ model = AutoModelForSequenceClassification.from_pretrained(
106
+ "microsoft/codebert-base",
107
+ num_labels=7
108
+ )
109
+
110
+ model = model.to(device)
111
+ print("βœ… CodeBERT model loaded!")
112
+ print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
113
+
114
+ # --- CELL ---
115
+
116
+ from torch.utils.data import DataLoader
117
+ from torch.optim import AdamW
118
+ from transformers import get_linear_schedule_with_warmup
119
+
120
+ # DataLoaders
121
+ train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
122
+ test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
123
+
124
+ # Optimizer
125
+ optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
126
+
127
+ # Scheduler
128
+ total_steps = len(train_loader) * 3 # 3 epochs
129
+ scheduler = get_linear_schedule_with_warmup(
130
+ optimizer,
131
+ num_warmup_steps=total_steps // 10,
132
+ num_training_steps=total_steps
133
+ )
134
+
135
+ print(f"βœ… DataLoaders ready!")
136
+ print(f"Total training steps: {total_steps}")
137
+ print(f"Steps per epoch: {len(train_loader)}")
138
+
139
+ # --- CELL ---
140
+
141
+ from tqdm import tqdm
142
+
143
+ def train_epoch(model, loader, optimizer, scheduler, device):
144
+ model.train()
145
+ total_loss = 0
146
+ correct = 0
147
+ total = 0
148
+
149
+ for batch in tqdm(loader, desc="Training"):
150
+ input_ids = batch['input_ids'].to(device)
151
+ attention_mask = batch['attention_mask'].to(device)
152
+ labels = batch['label'].to(device)
153
+
154
+ optimizer.zero_grad()
155
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
156
+ loss = outputs.loss
157
+ logits = outputs.logits
158
+
159
+ loss.backward()
160
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
161
+ optimizer.step()
162
+ scheduler.step()
163
+
164
+ total_loss += loss.item()
165
+ preds = torch.argmax(logits, dim=1)
166
+ correct += (preds == labels).sum().item()
167
+ total += labels.size(0)
168
+
169
+ return total_loss / len(loader), correct / total
170
+
171

def evaluate(model, loader, device):
    """Return classification accuracy of `model` over `loader` (no gradients)."""
    model.eval()
    hits = 0
    seen = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            # No labels passed: we only need logits for the argmax prediction.
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            hits += (logits.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)

    return hits / seen
189
+
190
+
191
+ # Train for 3 epochs
192
+ best_accuracy = 0
193
+
194
+ for epoch in range(3):
195
+ print(f"\nπŸ”„ Epoch {epoch+1}/3")
196
+ train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, device)
197
+ test_acc = evaluate(model, test_loader, device)
198
+
199
+ print(f"Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}% | Test Acc: {test_acc*100:.2f}%")
200
+
201
+ if test_acc > best_accuracy:
202
+ best_accuracy = test_acc
203
+ torch.save(model.state_dict(), "best_model.pt")
204
+ print(f"βœ… Best model saved! Accuracy: {best_accuracy*100:.2f}%")
205
+
206
+ # --- CELL ---
207
+
208
+ # Train 2 more epochs
209
+ for epoch in range(2):
210
+ print(f"\nπŸ”„ Epoch {epoch+4}/5")
211
+ train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, device)
212
+ test_acc = evaluate(model, test_loader, device)
213
+
214
+ print(f"Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}% | Test Acc: {test_acc*100:.2f}%")
215
+
216
+ if test_acc > best_accuracy:
217
+ best_accuracy = test_acc
218
+ torch.save(model.state_dict(), "best_model.pt")
219
+ print(f"βœ… Best model saved! Accuracy: {best_accuracy*100:.2f}%")
220
+
221
# --- CELL ---

from google.colab import drive
drive.mount('/content/drive')

# --- CELL ---

import pickle
import shutil

# FIX: `label_encoder.pkl` was never written anywhere in this script, so the
# copy below used to raise FileNotFoundError. Serialize the fitted encoder
# first (plain pickle; joblib.load on the serving side can also read it).
with open("label_encoder.pkl", "wb") as f:
    pickle.dump(le, f)

# Copy files to Google Drive
shutil.copy("best_model.pt", "/content/drive/MyDrive/best_model.pt")
shutil.copy("label_encoder.pkl", "/content/drive/MyDrive/label_encoder.pkl")

print("✅ Files saved to Google Drive!")
235
+
236
# --- CELL ---

# Test the model directly in Colab: a few hand-written snippets to sanity-
# check predictions before exporting (not a substitute for the test set).
test_codes = [
    "public int findMax(int[] arr) { int max = arr[0]; for (int i = 1; i < arr.length; i++) { if (arr[i] > max) max = arr[i]; } return max; }",
    "return arr[0];",
    "for(int i=0;i<n;i++) for(int j=0;j<n;j++) sum+=arr[i][j];",
]

for code in test_codes:
    # Tokenize exactly as during training (same max_length and padding).
    inputs = tokenizer(code, truncation=True, max_length=512, padding='max_length', return_tensors='pt')
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        pred = torch.argmax(outputs.logits, dim=1).item()

    # Map the integer class id back to its complexity label string.
    print(f"Code: {code[:50]}...")
    print(f"Predicted: {le.inverse_transform([pred])[0]}\n")
256
+
257
# --- CELL ---

import torch.nn as nn

# Count class frequencies — inverse-frequency weights so rare complexity
# classes contribute proportionally more to the loss than common ones.
class_counts = df['label'].value_counts().sort_index().values
total = sum(class_counts)
class_weights = torch.tensor([total/c for c in class_counts], dtype=torch.float).to(device)

print("Class weights:", class_weights)
267
+
268
# New training loop with weighted loss
def train_epoch_weighted(model, loader, optimizer, scheduler, device, weights):
    """One training pass using class-weighted cross-entropy.

    Unlike `train_epoch`, the loss is computed externally (the model is
    called without labels) so per-class weights can be applied.
    Returns (mean loss, accuracy).
    """
    model.train()
    criterion = nn.CrossEntropyLoss(weight=weights)
    running_loss = 0.0
    hits = 0
    seen = 0

    for batch in tqdm(loader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        loss = criterion(logits, labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        running_loss += loss.item()
        hits += (logits.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)

    return running_loss / len(loader), hits / seen
296
+
297
# Retrain with weights — a lower-LR fine-tune pass on top of the earlier runs,
# using a fresh optimizer/scheduler pair so the schedule starts over.
optimizer3 = AdamW(model.parameters(), lr=5e-6)
scheduler3 = get_linear_schedule_with_warmup(optimizer3, num_warmup_steps=30, num_training_steps=len(train_loader)*3)

for epoch in range(3):
    print(f"\n🔄 Epoch {epoch+1}/3")
    train_loss, train_acc = train_epoch_weighted(model, train_loader, optimizer3, scheduler3, device, class_weights)
    test_acc = evaluate(model, test_loader, device)
    print(f"Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}% | Test Acc: {test_acc*100:.2f}%")
    # Still checkpoint against the best accuracy from the earlier runs.
    if test_acc > best_accuracy:
        best_accuracy = test_acc
        torch.save(model.state_dict(), "best_model.pt")
        print(f"✅ Best model saved! Accuracy: {best_accuracy*100:.2f}%")
310
+
311
# --- CELL ---

import shutil
# Final backup of the (possibly re-improved) checkpoint to Drive.
shutil.copy("best_model.pt", "/content/drive/MyDrive/best_model.pt")
print("✅ Saved to Google Drive!")