Spaces:
No application file
No application file
Commit ·
24be017
0
Parent(s):
Initial commit
Browse files- .gitattributes +35 -0
- .gitignore +9 -0
- Dockerfile +20 -0
- README.md +13 -0
- Untitled2.ipynb +0 -0
- backend/main.py +84 -0
- download_model.py +18 -0
- frontend/index.html +99 -0
- frontend/script.js +91 -0
- frontend/styles.css +333 -0
- improved_training.py +208 -0
- render.yaml +6 -0
- requirements.txt +8 -0
- train_extracted.py +315 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
.env
|
| 4 |
+
venv/
|
| 5 |
+
.DS_Store
|
| 6 |
+
|
| 7 |
+
# Large Model Files
|
| 8 |
+
best_model.pt
|
| 9 |
+
label_encoder.pkl
|
Dockerfile
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.9-slim
|
| 2 |
+
|
| 3 |
+
# Set work directory
|
| 4 |
+
WORKDIR /app
|
| 5 |
+
|
| 6 |
+
# Install dependencies
|
| 7 |
+
COPY requirements.txt .
|
| 8 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 9 |
+
|
| 10 |
+
# Copy the entire project
|
| 11 |
+
COPY . .
|
| 12 |
+
|
| 13 |
+
# Download model (bypasses GitHub LFS issues)
|
| 14 |
+
RUN python download_model.py
|
| 15 |
+
|
| 16 |
+
# Expose port for Render
|
| 17 |
+
EXPOSE 8000
|
| 18 |
+
|
| 19 |
+
# Run FastAPI via Uvicorn
|
| 20 |
+
CMD ["uvicorn", "backend.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
README.md
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Code Complexity Predictor
|
| 3 |
+
emoji: π
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: yellow
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 6.9.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: apache-2.0
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
Untitled2.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
backend/main.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import torch
import joblib
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Display metadata per predicted class: label -> (Big-O notation, card title, explanation).
DESCRIPTIONS = {
    "constant": ("O(1)", "β‘ Constant Time", "Executes in the same time regardless of input size. Very fast!"),
    "linear": ("O(n)", "π Linear Time", "Execution time grows linearly with input size."),
    "logn": ("O(log n)", "π Logarithmic Time", "Very efficient! Common in binary search algorithms."),
    "nlogn": ("O(n log n)", "βοΈ Linearithmic Time", "Common in efficient sorting algorithms like merge sort."),
    "quadratic": ("O(nΒ²)", "π’ Quadratic Time", "Execution time grows quadratically. Common in nested loops."),
    "cubic": ("O(nΒ³)", "π¦ Cubic Time", "Triple nested loops. Avoid for large inputs."),
    "np": ("O(2βΏ)", "π Exponential Time", "NP-Hard complexity. Only feasible for very small inputs."),
}

app = FastAPI(title="Code Complexity Predictor API")


class PredictRequest(BaseModel):
    """Request body for /api/predict."""

    # Raw source code to classify.
    code: str


# Module-level ML resources; populated exactly once by the startup hook below.
model = None
tokenizer = None
le = None
device = None
|
| 31 |
+
@app.on_event("startup")
def load_resources():
    """Load the tokenizer, label encoder and fine-tuned weights once at startup.

    Populates the module-level ``model``, ``tokenizer``, ``le`` and ``device``
    globals. Missing artifact files are reported but do not abort startup
    (best-effort: the app still serves the frontend).
    """
    global model, tokenizer, le, device
    print("Loading resources...")
    # Prefer GPU when one is available; everything below is moved accordingly.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Tokenizer comes straight from the Hugging Face hub.
    tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")

    # Label encoder maps class indices back to complexity labels.
    if not os.path.exists("label_encoder.pkl"):
        print("WARNING: label_encoder.pkl not found!")
    else:
        le = joblib.load("label_encoder.pkl")

    # Base architecture plus (when present) the fine-tuned checkpoint.
    model = AutoModelForSequenceClassification.from_pretrained(
        "microsoft/codebert-base", num_labels=7
    )
    if not os.path.exists("best_model.pt"):
        print("WARNING: best_model.pt not found!")
    else:
        model.load_state_dict(torch.load("best_model.pt", map_location=device))

    model.to(device)
    model.eval()
    print("Resources loaded successfully!")
|
| 57 |
+
@app.post("/api/predict")
def predict_complexity(request: PredictRequest):
    """Classify the submitted code's asymptotic complexity.

    Returns a dict with ``notation``, ``title`` and ``description`` for the
    predicted class. Raises 400 for blank input and 500 for any inference
    failure (e.g. resources not loaded).
    """
    code = request.code
    if not code.strip():
        raise HTTPException(status_code=400, detail="Code cannot be empty")

    try:
        encoded = tokenizer(
            code,
            truncation=True,
            max_length=512,
            padding='max_length',
            return_tensors='pt',
        )

        # Inference only: no gradients needed.
        with torch.no_grad():
            outputs = model(
                input_ids=encoded['input_ids'].to(device),
                attention_mask=encoded['attention_mask'].to(device),
            )
            predicted_idx = torch.argmax(outputs.logits, dim=1).item()

        # Map the class index back to its label, then to display metadata.
        label = le.inverse_transform([predicted_idx])[0]
        notation, title, description = DESCRIPTIONS.get(label, (label, label, ""))
        return {
            "notation": notation,
            "title": title,
            "description": description,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Serve the static frontend at the root path (mounted after the API routes so
# "/api/*" keeps precedence).
app.mount("/", StaticFiles(directory="frontend", html=True), name="frontend")
download_model.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import urllib.request
import os
import shutil

def download_file(url, target_path):
    """Download *url* to *target_path*, streaming to disk in chunks.

    The previous implementation buffered the entire response in memory
    (``response.read()``) before writing; the model checkpoint is hundreds
    of MB, so we stream with ``shutil.copyfileobj`` instead.
    """
    print(f"Downloading {target_path} from {url}...")
    # Some hosts reject the default urllib user agent, hence the header.
    req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
    with urllib.request.urlopen(req) as response, open(target_path, 'wb') as out_file:
        shutil.copyfileobj(response, out_file)
    size_mb = os.path.getsize(target_path) / (1024 * 1024)
    print(f"Successfully downloaded {target_path} (Size: {size_mb:.2f} MB)")

# Model artifacts are fetched from the Hugging Face Space at build time
# (bypasses GitHub LFS bandwidth issues).
url_model = "https://huggingface.co/spaces/raghuram00/code-complexity-predictor/resolve/main/best_model.pt"
url_le = "https://huggingface.co/spaces/raghuram00/code-complexity-predictor/resolve/main/label_encoder.pkl"

# Only download artifacts that are not already present (e.g. cached layer).
if not os.path.exists("best_model.pt"):
    download_file(url_model, "best_model.pt")

if not os.path.exists("label_encoder.pkl"):
    download_file(url_le, "label_encoder.pkl")
|
frontend/index.html
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>Code Complexity Predictor</title>
|
| 7 |
+
<!-- Modern Fonts -->
|
| 8 |
+
<link rel="preconnect" href="https://fonts.googleapis.com">
|
| 9 |
+
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
| 10 |
+
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&family=JetBrains+Mono:wght@400;700&family=Outfit:wght@600;800&display=swap" rel="stylesheet">
|
| 11 |
+
|
| 12 |
+
<!-- Prism.js for code syntax highlighting -->
|
| 13 |
+
<link href="https://cdnjs.cloudflare.com/ajax/libs/prism/1.29.0/themes/prism-tomorrow.min.css" rel="stylesheet" />
|
| 14 |
+
|
| 15 |
+
<!-- Custom Styles -->
|
| 16 |
+
<link rel="stylesheet" href="styles.css">
|
| 17 |
+
</head>
|
| 18 |
+
<body>
|
| 19 |
+
<div class="background-effects">
|
| 20 |
+
<div class="glow-orb orb-1"></div>
|
| 21 |
+
<div class="glow-orb orb-2"></div>
|
| 22 |
+
</div>
|
| 23 |
+
|
| 24 |
+
<main class="container">
|
| 25 |
+
<header>
|
| 26 |
+
<div class="logo-wrapper">
|
| 27 |
+
<svg width="40" height="40" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
| 28 |
+
<path d="M10 20L14 4M18 8L22 12L18 16M6 16L2 12L6 8" stroke="url(#paint0_linear)" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
| 29 |
+
<defs>
|
| 30 |
+
<linearGradient id="paint0_linear" x1="2" y1="12" x2="22" y2="12" gradientUnits="userSpaceOnUse">
|
| 31 |
+
<stop stop-color="#00F0FF" />
|
| 32 |
+
<stop offset="1" stop-color="#8000FF" />
|
| 33 |
+
</linearGradient>
|
| 34 |
+
</defs>
|
| 35 |
+
</svg>
|
| 36 |
+
<h1>Complexity <span>Predictor</span></h1>
|
| 37 |
+
</div>
|
| 38 |
+
<p class="subtitle">// AI-powered algorithmic Big-O analysis using CodeBERT</p>
|
| 39 |
+
</header>
|
| 40 |
+
|
| 41 |
+
<div class="app-grid">
|
| 42 |
+
<!-- Left Side: Code Input -->
|
| 43 |
+
<section class="editor-section">
|
| 44 |
+
<div class="card glass-panel">
|
| 45 |
+
<div class="card-header">
|
| 46 |
+
<span class="dot red"></span>
|
| 47 |
+
<span class="dot yellow"></span>
|
| 48 |
+
<span class="dot green"></span>
|
| 49 |
+
<span class="filename">algorithm.py</span>
|
| 50 |
+
</div>
|
| 51 |
+
<div class="editor-wrapper">
|
| 52 |
+
<!-- We use a textarea over a pre code block to allow editing,
|
| 53 |
+
but styling them to overlap for syntax highlighting -->
|
| 54 |
+
<textarea id="codeInput" placeholder="# Paste your Python or Java code here...
|
| 55 |
+
def example(arr):
|
| 56 |
+
for item in arr:
|
| 57 |
+
print(item)"></textarea>
|
| 58 |
+
</div>
|
| 59 |
+
</div>
|
| 60 |
+
<button id="analyzeBtn" class="primary-btn">
|
| 61 |
+
<span class="btn-text">β‘ Analyze Complexity</span>
|
| 62 |
+
<span class="loader hidden"></span>
|
| 63 |
+
</button>
|
| 64 |
+
|
| 65 |
+
<div class="examples-section">
|
| 66 |
+
<h3>Try these examples:</h3>
|
| 67 |
+
<div class="example-chips">
|
| 68 |
+
<button class="chip" data-example='def get_first(arr):\n return arr[0]'>O(1) Constant</button>
|
| 69 |
+
<button class="chip" data-example='def linear_search(arr, target):\n for i in range(len(arr)):\n if arr[i] == target:\n return i\n return -1'>O(n) Linear</button>
|
| 70 |
+
<button class="chip" data-example='def bubble_sort(arr):\n n = len(arr)\n for i in range(n):\n for j in range(0, n-i-1):\n if arr[j] > arr[j+1]:\n arr[j], arr[j+1] = arr[j+1], arr[j]'>O(nΒ²) Quadratic</button>
|
| 71 |
+
</div>
|
| 72 |
+
</div>
|
| 73 |
+
</section>
|
| 74 |
+
|
| 75 |
+
<!-- Right Side: Results -->
|
| 76 |
+
<section class="results-section">
|
| 77 |
+
<div class="result-card glass-panel">
|
| 78 |
+
<h2 class="result-label">Big-O Notation</h2>
|
| 79 |
+
<div class="result-value glow-text" id="resNotation">O(?)</div>
|
| 80 |
+
</div>
|
| 81 |
+
|
| 82 |
+
<div class="result-card glass-panel">
|
| 83 |
+
<h2 class="result-label">Complexity Class</h2>
|
| 84 |
+
<div class="result-value" id="resTitle">-</div>
|
| 85 |
+
</div>
|
| 86 |
+
|
| 87 |
+
<div class="result-card glass-panel desc-card">
|
| 88 |
+
<h2 class="result-label">Explanation</h2>
|
| 89 |
+
<p class="result-desc" id="resDesc">Awaiting code input to predict complexity.</p>
|
| 90 |
+
</div>
|
| 91 |
+
</section>
|
| 92 |
+
</div>
|
| 93 |
+
</main>
|
| 94 |
+
|
| 95 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.29.0/prism.min.js"></script>
|
| 96 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.29.0/components/prism-python.min.js"></script>
|
| 97 |
+
<script src="script.js"></script>
|
| 98 |
+
</body>
|
| 99 |
+
</html>
|
frontend/script.js
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
document.addEventListener('DOMContentLoaded', () => {
    // --- Cached DOM handles -------------------------------------------------
    const codeInput = document.getElementById('codeInput');
    const analyzeBtn = document.getElementById('analyzeBtn');
    const btnText = analyzeBtn.querySelector('.btn-text');
    const loader = analyzeBtn.querySelector('.loader');

    const resNotation = document.getElementById('resNotation');
    const resTitle = document.getElementById('resTitle');
    const resDesc = document.getElementById('resDesc');

    /** Write one set of values into the three result cards. */
    function renderResult(notation, title, description) {
        resNotation.innerHTML = notation;
        resTitle.innerHTML = title;
        resDesc.innerHTML = description;
    }

    /** Toggle button/loader/result opacity for an in-flight request. */
    function setBusy(busy) {
        analyzeBtn.disabled = busy;
        btnText.innerHTML = busy ? "Analyzing structure..." : "β‘ Analyze Complexity";
        loader.classList.toggle('hidden', !busy);
        const opacity = busy ? '0.5' : '1';
        resNotation.style.opacity = opacity;
        resTitle.style.opacity = opacity;
        resDesc.style.opacity = opacity;
    }

    /** Brief "pop" scale animation on every result card. */
    function popResults() {
        document.querySelectorAll('.result-card').forEach(card => {
            card.style.transform = 'scale(1.02)';
            setTimeout(() => card.style.transform = 'scale(1)', 200);
        });
    }

    /** Send the editor contents to the backend and render the prediction. */
    async function triggerAnalysis() {
        const code = codeInput.value.trim();

        if (!code) {
            renderResult("O(?)", "Error", "β οΈ Please paste some code before analyzing!");
            return;
        }

        setBusy(true);
        try {
            const response = await fetch('/api/predict', {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json'
                },
                body: JSON.stringify({ code: code })
            });

            if (!response.ok) {
                throw new Error(`Server error: ${response.status}`);
            }

            const data = await response.json();
            renderResult(data.notation, data.title, data.description);
        } catch (error) {
            console.error('Analysis failed:', error);
            renderResult("O(?)", "Analysis Failed", "An error occurred while connecting to the AI model. Ensure the backend is running.");
        } finally {
            setBusy(false);
            popResults();
        }
    }

    // Example chips: the dataset attribute stores "\n" literally, so expand it
    // into real newlines before analyzing.
    document.querySelectorAll('.chip').forEach(chip => {
        chip.addEventListener('click', () => {
            codeInput.value = chip.dataset.example.replace(/\\n/g, '\n');
            triggerAnalysis();
        });
    });

    analyzeBtn.addEventListener('click', triggerAnalysis);

    // Cmd/Ctrl + Enter inside the editor also triggers analysis.
    codeInput.addEventListener('keydown', (e) => {
        if ((e.ctrlKey || e.metaKey) && e.key === 'Enter') {
            triggerAnalysis();
        }
    });
});
|
frontend/styles.css
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
:root {
|
| 2 |
+
--bg-color: #050508;
|
| 3 |
+
--panel-bg: rgba(22, 22, 30, 0.6);
|
| 4 |
+
--border-color: rgba(255, 255, 255, 0.08);
|
| 5 |
+
--text-main: #f0f0f5;
|
| 6 |
+
--text-dim: #8b8b9f;
|
| 7 |
+
--accent-primary: #00F0FF;
|
| 8 |
+
--accent-secondary: #8000FF;
|
| 9 |
+
--glass-blur: 16px;
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
* {
|
| 13 |
+
margin: 0;
|
| 14 |
+
padding: 0;
|
| 15 |
+
box-sizing: border-box;
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
body {
|
| 19 |
+
background-color: var(--bg-color);
|
| 20 |
+
color: var(--text-main);
|
| 21 |
+
font-family: 'Inter', sans-serif;
|
| 22 |
+
min-height: 100vh;
|
| 23 |
+
display: flex;
|
| 24 |
+
justify-content: center;
|
| 25 |
+
position: relative;
|
| 26 |
+
overflow-x: hidden;
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
/* Animated Background Orbs */
|
| 30 |
+
.background-effects {
|
| 31 |
+
position: fixed;
|
| 32 |
+
top: 0;
|
| 33 |
+
left: 0;
|
| 34 |
+
width: 100vw;
|
| 35 |
+
height: 100vh;
|
| 36 |
+
z-index: -1;
|
| 37 |
+
overflow: hidden;
|
| 38 |
+
pointer-events: none;
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
.glow-orb {
|
| 42 |
+
position: absolute;
|
| 43 |
+
border-radius: 50%;
|
| 44 |
+
filter: blur(120px);
|
| 45 |
+
opacity: 0.4;
|
| 46 |
+
animation: float 20s infinite alternate ease-in-out;
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
.orb-1 {
|
| 50 |
+
width: 600px;
|
| 51 |
+
height: 600px;
|
| 52 |
+
background: var(--accent-secondary);
|
| 53 |
+
top: -200px;
|
| 54 |
+
right: -100px;
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
.orb-2 {
|
| 58 |
+
width: 500px;
|
| 59 |
+
height: 500px;
|
| 60 |
+
background: var(--accent-primary);
|
| 61 |
+
bottom: -200px;
|
| 62 |
+
left: -100px;
|
| 63 |
+
animation-delay: -10s;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
@keyframes float {
|
| 67 |
+
0% { transform: translate(0, 0) scale(1); }
|
| 68 |
+
100% { transform: translate(-100px, 100px) scale(1.2); }
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
.container {
|
| 72 |
+
width: 100%;
|
| 73 |
+
max-width: 1200px;
|
| 74 |
+
padding: 3rem 2rem;
|
| 75 |
+
display: flex;
|
| 76 |
+
flex-direction: column;
|
| 77 |
+
gap: 3rem;
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
/* Header */
|
| 81 |
+
header {
|
| 82 |
+
text-align: center;
|
| 83 |
+
display: flex;
|
| 84 |
+
flex-direction: column;
|
| 85 |
+
align-items: center;
|
| 86 |
+
gap: 1rem;
|
| 87 |
+
animation: fadeDown 0.8s ease-out;
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
.logo-wrapper {
|
| 91 |
+
display: flex;
|
| 92 |
+
align-items: center;
|
| 93 |
+
gap: 1rem;
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
h1 {
|
| 97 |
+
font-family: 'Outfit', sans-serif;
|
| 98 |
+
font-size: 3.5rem;
|
| 99 |
+
font-weight: 800;
|
| 100 |
+
letter-spacing: -1px;
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
h1 span {
|
| 104 |
+
background: linear-gradient(135deg, var(--accent-primary), var(--accent-secondary));
|
| 105 |
+
-webkit-background-clip: text;
|
| 106 |
+
-webkit-text-fill-color: transparent;
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
.subtitle {
|
| 110 |
+
font-family: 'JetBrains Mono', monospace;
|
| 111 |
+
color: var(--text-dim);
|
| 112 |
+
font-size: 0.9rem;
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
/* Main Grid */
|
| 116 |
+
.app-grid {
|
| 117 |
+
display: grid;
|
| 118 |
+
grid-template-columns: 1.5fr 1fr;
|
| 119 |
+
gap: 2rem;
|
| 120 |
+
align-items: start;
|
| 121 |
+
animation: fadeUp 1s ease-out;
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
/* Glassmorphism Panels */
|
| 125 |
+
.glass-panel {
|
| 126 |
+
background: var(--panel-bg);
|
| 127 |
+
backdrop-filter: blur(var(--glass-blur));
|
| 128 |
+
-webkit-backdrop-filter: blur(var(--glass-blur));
|
| 129 |
+
border: 1px solid var(--border-color);
|
| 130 |
+
border-radius: 20px;
|
| 131 |
+
box-shadow: 0 8px 32px 0 rgba(0, 0, 0, 0.3);
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
/* Editor Section */
|
| 135 |
+
.card-header {
|
| 136 |
+
padding: 1rem;
|
| 137 |
+
border-bottom: 1px solid var(--border-color);
|
| 138 |
+
display: flex;
|
| 139 |
+
align-items: center;
|
| 140 |
+
gap: 0.5rem;
|
| 141 |
+
background: rgba(0,0,0,0.2);
|
| 142 |
+
border-radius: 20px 20px 0 0;
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
.dot {
|
| 146 |
+
width: 12px;
|
| 147 |
+
height: 12px;
|
| 148 |
+
border-radius: 50%;
|
| 149 |
+
}
|
| 150 |
+
.red { background: #ff5f56; }
|
| 151 |
+
.yellow { background: #ffbd2e; }
|
| 152 |
+
.green { background: #27c93f; }
|
| 153 |
+
|
| 154 |
+
.filename {
|
| 155 |
+
margin-left: 1rem;
|
| 156 |
+
font-family: 'JetBrains Mono', monospace;
|
| 157 |
+
font-size: 0.8rem;
|
| 158 |
+
color: var(--text-dim);
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
.editor-wrapper {
|
| 162 |
+
position: relative;
|
| 163 |
+
padding: 1rem;
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
textarea {
|
| 167 |
+
width: 100%;
|
| 168 |
+
min-height: 350px;
|
| 169 |
+
background: transparent;
|
| 170 |
+
border: none;
|
| 171 |
+
color: var(--text-main);
|
| 172 |
+
font-family: 'JetBrains Mono', monospace;
|
| 173 |
+
font-size: 0.9rem;
|
| 174 |
+
line-height: 1.6;
|
| 175 |
+
resize: vertical;
|
| 176 |
+
outline: none;
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
textarea::placeholder {
|
| 180 |
+
color: rgba(255, 255, 255, 0.2);
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
/* Button */
|
| 184 |
+
.primary-btn {
|
| 185 |
+
width: 100%;
|
| 186 |
+
margin-top: 1.5rem;
|
| 187 |
+
padding: 1.2rem;
|
| 188 |
+
border: none;
|
| 189 |
+
border-radius: 12px;
|
| 190 |
+
background: linear-gradient(135deg, var(--accent-primary), var(--accent-secondary));
|
| 191 |
+
color: #fff;
|
| 192 |
+
font-family: 'Outfit', sans-serif;
|
| 193 |
+
font-size: 1.2rem;
|
| 194 |
+
font-weight: 600;
|
| 195 |
+
cursor: pointer;
|
| 196 |
+
transition: all 0.3s ease;
|
| 197 |
+
display: flex;
|
| 198 |
+
justify-content: center;
|
| 199 |
+
align-items: center;
|
| 200 |
+
gap: 1rem;
|
| 201 |
+
position: relative;
|
| 202 |
+
overflow: hidden;
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
.primary-btn::before {
|
| 206 |
+
content: '';
|
| 207 |
+
position: absolute;
|
| 208 |
+
top: 0; left: -100%;
|
| 209 |
+
width: 100%; height: 100%;
|
| 210 |
+
background: linear-gradient(90deg, transparent, rgba(255,255,255,0.2), transparent);
|
| 211 |
+
transition: 0.5s;
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
.primary-btn:hover::before {
|
| 215 |
+
left: 100%;
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
.primary-btn:hover {
|
| 219 |
+
transform: translateY(-2px);
|
| 220 |
+
box-shadow: 0 10px 20px rgba(128, 0, 255, 0.3);
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
.primary-btn:active {
|
| 224 |
+
transform: translateY(0);
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
/* Loader */
|
| 228 |
+
.loader {
|
| 229 |
+
width: 20px;
|
| 230 |
+
height: 20px;
|
| 231 |
+
border: 3px solid rgba(255,255,255,0.3);
|
| 232 |
+
border-radius: 50%;
|
| 233 |
+
border-top-color: #fff;
|
| 234 |
+
animation: spin 1s ease-in-out infinite;
|
| 235 |
+
}
|
| 236 |
+
.hidden { display: none; }
|
| 237 |
+
|
| 238 |
+
@keyframes spin {
|
| 239 |
+
to { transform: rotate(360deg); }
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
/* Examples */
|
| 243 |
+
.examples-section {
|
| 244 |
+
margin-top: 2rem;
|
| 245 |
+
}
|
| 246 |
+
.examples-section h3 {
|
| 247 |
+
font-size: 0.9rem;
|
| 248 |
+
color: var(--text-dim);
|
| 249 |
+
margin-bottom: 1rem;
|
| 250 |
+
font-weight: 500;
|
| 251 |
+
}
|
| 252 |
+
.example-chips {
|
| 253 |
+
display: flex;
|
| 254 |
+
gap: 0.8rem;
|
| 255 |
+
flex-wrap: wrap;
|
| 256 |
+
}
|
| 257 |
+
.chip {
|
| 258 |
+
background: rgba(255,255,255,0.05);
|
| 259 |
+
border: 1px solid rgba(255,255,255,0.1);
|
| 260 |
+
color: var(--text-main);
|
| 261 |
+
padding: 0.5rem 1rem;
|
| 262 |
+
border-radius: 20px;
|
| 263 |
+
font-family: 'JetBrains Mono', monospace;
|
| 264 |
+
font-size: 0.8rem;
|
| 265 |
+
cursor: pointer;
|
| 266 |
+
transition: all 0.2s;
|
| 267 |
+
}
|
| 268 |
+
.chip:hover {
|
| 269 |
+
background: rgba(255,255,255,0.1);
|
| 270 |
+
border-color: var(--accent-primary);
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
/* Results Section */
|
| 274 |
+
.results-section {
|
| 275 |
+
display: flex;
|
| 276 |
+
flex-direction: column;
|
| 277 |
+
gap: 1.5rem;
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
.result-card {
|
| 281 |
+
padding: 2rem;
|
| 282 |
+
transition: transform 0.3s;
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
.result-card:hover {
|
| 286 |
+
transform: translateY(-2px);
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
.result-label {
|
| 290 |
+
font-size: 0.8rem;
|
| 291 |
+
text-transform: uppercase;
|
| 292 |
+
letter-spacing: 2px;
|
| 293 |
+
color: var(--text-dim);
|
| 294 |
+
margin-bottom: 1rem;
|
| 295 |
+
}
|
| 296 |
+
|
| 297 |
+
.result-value {
|
| 298 |
+
font-family: 'Outfit', sans-serif;
|
| 299 |
+
font-size: 2rem;
|
| 300 |
+
font-weight: 800;
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
.glow-text {
|
| 304 |
+
background: linear-gradient(135deg, var(--accent-primary), #fff);
|
| 305 |
+
-webkit-background-clip: text;
|
| 306 |
+
-webkit-text-fill-color: transparent;
|
| 307 |
+
font-size: 3rem;
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
.result-desc {
|
| 311 |
+
color: #a0a0b5;
|
| 312 |
+
line-height: 1.6;
|
| 313 |
+
font-size: 1.1rem;
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
/* Animations */
|
| 317 |
+
@keyframes fadeDown {
|
| 318 |
+
from { opacity: 0; transform: translateY(-20px); }
|
| 319 |
+
to { opacity: 1; transform: translateY(0); }
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
@keyframes fadeUp {
|
| 323 |
+
from { opacity: 0; transform: translateY(20px); }
|
| 324 |
+
to { opacity: 1; transform: translateY(0); }
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
/* Responsive */
|
| 328 |
+
@media (max-width: 900px) {
|
| 329 |
+
.app-grid {
|
| 330 |
+
grid-template-columns: 1fr;
|
| 331 |
+
}
|
| 332 |
+
h1 { font-size: 2.5rem; }
|
| 333 |
+
}
|
improved_training.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ==============================================================================
# 🚀 IMPROVED CODE COMPLEXITY PREDICTOR TRAINING SCRIPT 🚀
# ==============================================================================
# Run this entire script in Google Colab (either pasted into a cell or via script)

# 1. Install dependencies
# !pip install -q transformers datasets torch scikit-learn

import pandas as pd
import torch
import torch.nn as nn
from datasets import load_dataset
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import get_linear_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from tqdm import tqdm
import os
import shutil

# ------------------------------------------------------------------------------
# ⚙️ CONFIGURATION & HYPERPARAMETERS
# ------------------------------------------------------------------------------
MODEL_NAME = "microsoft/graphcodebert-base"  # 🔥 Upgraded to GraphCodeBERT
MAX_LEN = 512           # Max token length (tokenizer pads/truncates to this)
BATCH_SIZE = 16         # Training batch size
EPOCHS = 15             # 🔥 Increased from 3 to 15 (early stopping caps it)
LEARNING_RATE = 3e-5    # Optimized initial learning rate
WEIGHT_DECAY = 0.05     # 🔥 Increased Regularization
PATIENCE = 3            # 🔥 Early Stopping patience (epochs without improvement)
SAVE_PATH = "best_model.pt"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🖥️ Using device: {device}")

# ------------------------------------------------------------------------------
# 📚 DATA PREPARATION
# ------------------------------------------------------------------------------
print("\n[1/5] Loading Dataset...")
dataset = load_dataset("codeparrot/codecomplex")
df = pd.DataFrame(dataset['train'])

# Encode labels: complexity strings (e.g. "linear") -> integer class ids
le = LabelEncoder()
df['label'] = le.fit_transform(df['complexity'])

# Save Label Encoder for Inference (the API needs the same id->label mapping)
import joblib
joblib.dump(le, "label_encoder.pkl")

# Train/Test Split (stratified so every complexity class appears in both splits)
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['label'])

# Calculate Class Weights to handle imbalance right from the start
# (inverse-frequency weighting computed on the TRAIN split only)
class_counts = train_df['label'].value_counts().sort_index().values
total_samples = sum(class_counts)
class_weights = torch.tensor([total_samples / c for c in class_counts], dtype=torch.float).to(device)

print(f"✅ Loaded {len(train_df)} training and {len(test_df)} testing samples.")

# ------------------------------------------------------------------------------
# 🔧 TOKENIZATION & DATASETS
# ------------------------------------------------------------------------------
print(f"\n[2/5] Initializing Tokenizer ({MODEL_NAME})...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 68 |
+
|
| 69 |
+
class CodeDataset(Dataset):
    """PyTorch dataset over a DataFrame of code samples.

    Expects columns:
        'src'   - source code text
        'label' - integer class id produced by the LabelEncoder
    """

    def __init__(self, dataframe, tokenizer, max_length=MAX_LEN):
        self.data = dataframe        # DataFrame holding 'src' and 'label'
        self.tokenizer = tokenizer   # HuggingFace tokenizer callable
        self.max_length = max_length # fixed sequence length (pad/truncate)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        code = str(self.data.iloc[idx]['src'])
        label = int(self.data.iloc[idx]['label'])

        # Tokenize to a fixed length so the default collate can stack batches.
        encoding = self.tokenizer(
            code,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )

        # squeeze() drops the batch dim added by the tokenizer: (1, L) -> (L,)
        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'label': torch.tensor(label, dtype=torch.long)
        }
|
| 95 |
+
|
| 96 |
+
# Build PyTorch datasets/loaders from the stratified split.
train_dataset = CodeDataset(train_df.reset_index(drop=True), tokenizer)
test_dataset = CodeDataset(test_df.reset_index(drop=True), tokenizer)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# ------------------------------------------------------------------------------
# 🏗️ MODEL INITIALIZATION
# ------------------------------------------------------------------------------
print(f"\n[3/5] Loading Model ({MODEL_NAME})...")
# Derive the label count from the fitted LabelEncoder instead of hard-coding 7,
# so the classifier head always matches the dataset's actual class set.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=len(le.classes_))
model = model.to(device)

# Optimizer with Weight Decay (decoupled L2 regularization)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Scheduler: linear decay with warmup, sized for the full planned run.
total_steps = len(train_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(total_steps * 0.1),  # 10% warmup
    num_training_steps=total_steps
)

# Loss function with balanced classes (weights computed from the train split)
criterion = nn.CrossEntropyLoss(weight=class_weights)
|
| 122 |
+
|
| 123 |
+
# ------------------------------------------------------------------------------
|
| 124 |
+
# π TRAINING & EVALUATION FUNCTIONS
|
| 125 |
+
# ------------------------------------------------------------------------------
|
| 126 |
+
def train_epoch(model, loader, optimizer, scheduler, criterion, device):
    """Run one training pass over `loader`; return (mean batch loss, accuracy)."""
    model.train()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    for batch in tqdm(loader, desc="Training", leave=False):
        ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        targets = batch['label'].to(device)

        optimizer.zero_grad()
        logits = model(input_ids=ids, attention_mask=mask).logits
        loss = criterion(logits, targets)

        loss.backward()
        # Clip gradients to stabilize fine-tuning.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        running_loss += loss.item()
        n_correct += (logits.argmax(dim=1) == targets).sum().item()
        n_seen += targets.size(0)

    return running_loss / len(loader), n_correct / n_seen
|
| 150 |
+
|
| 151 |
+
def evaluate(model, loader, device):
    """Compute classification accuracy of `model` over `loader` (no gradients)."""
    model.eval()
    n_correct, n_seen = 0, 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating", leave=False):
            targets = batch['label'].to(device)
            logits = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            ).logits
            n_correct += (logits.argmax(dim=1) == targets).sum().item()
            n_seen += targets.size(0)

    return n_correct / n_seen
|
| 167 |
+
|
| 168 |
+
# ------------------------------------------------------------------------------
# 🔥 MAIN TRAINING LOOP WITH EARLY STOPPING
# ------------------------------------------------------------------------------
print("\n[4/5] Starting Training Loop...")
best_accuracy = 0        # best test accuracy seen so far
epochs_no_improve = 0    # consecutive epochs without a new best

for epoch in range(EPOCHS):
    print(f"\n🚀 Epoch {epoch+1}/{EPOCHS}")

    train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, criterion, device)
    test_acc = evaluate(model, test_loader, device)

    print(f"📊 Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}% | Test Acc: {test_acc*100:.2f}%")

    # Early Stopping Logic: checkpoint on any improvement, otherwise count stalls.
    if test_acc > best_accuracy:
        best_accuracy = test_acc
        epochs_no_improve = 0
        torch.save(model.state_dict(), SAVE_PATH)
        print(f"⭐ NEW BEST MODEL SAVED! Accuracy: {best_accuracy*100:.2f}%")
    else:
        epochs_no_improve += 1
        print(f"⚠️ No improvement for {epochs_no_improve} epochs.")

    if epochs_no_improve >= PATIENCE:
        print(f"\n⏹️ EARLY STOPPING TRIGGERED! Test accuracy hasn't improved in {PATIENCE} epochs.")
        break

# ------------------------------------------------------------------------------
# 💾 EXPORTING TO DRIVE
# ------------------------------------------------------------------------------
print("\n[5/5] Finalizing...")
try:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    shutil.copy(SAVE_PATH, f"/content/drive/MyDrive/{SAVE_PATH}")
    shutil.copy("label_encoder.pkl", "/content/drive/MyDrive/label_encoder.pkl")
    print("✅ Files successfully backed up to Google Drive!")
except ImportError:
    # Running outside Colab — the checkpoint stays in the working directory.
    print("Not running in Colab - skipping Drive export.")
|
render.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
services:
  - type: web
    name: code-complexity-predictor
    env: docker
    # Fixed: "instanceCT" is not a valid Render blueprint field; the instance
    # count key is "numInstances".
    numInstances: 1
    plan: free
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi
|
| 2 |
+
uvicorn
|
| 3 |
+
transformers
|
| 4 |
+
torch
|
| 5 |
+
scikit-learn
|
| 6 |
+
pandas
|
| 7 |
+
joblib
|
| 8 |
+
python-multipart
|
train_extracted.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# NOTE: this file was extracted from a notebook. The shell magic below is a
# SyntaxError in plain Python, so it must stay commented out; run it manually
# (or via pip) before executing this script.
# !pip install transformers datasets torch scikit-learn

# --- CELL ---

from datasets import load_dataset

dataset = load_dataset("codeparrot/codecomplex")
print(dataset)
print(dataset['train'][0])

# --- CELL ---

import pandas as pd

df = pd.DataFrame(dataset['train'])

# Check complexity labels
print("Complexity classes:")
print(df['complexity'].value_counts())

print("\nLanguages:")
print(df['from'].value_counts())

print("\nTotal samples:", len(df))

# --- CELL ---

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

# Encode labels: complexity strings -> integer class ids
le = LabelEncoder()
df['label'] = le.fit_transform(df['complexity'])

print("Label mapping:")
for i, cls in enumerate(le.classes_):
    print(f"  {cls} → {i}")

# Split data (stratified so every class appears in both splits)
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['label'])

print(f"\nTrain size: {len(train_df)}")
print(f"Test size: {len(test_df)}")

# --- CELL ---

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")

print("✅ CodeBERT tokenizer loaded!")

# Test it on a short sample to sanity-check the tokenizer output shape.
sample = df['src'][0][:200]
tokens = tokenizer(sample, truncation=True, max_length=512, return_tensors="pt")
print("Sample token shape:", tokens['input_ids'].shape)
|
| 57 |
+
|
| 58 |
+
# --- CELL ---
|
| 59 |
+
|
| 60 |
+
import torch
|
| 61 |
+
from torch.utils.data import Dataset
|
| 62 |
+
|
| 63 |
+
class CodeDataset(Dataset):
    """Torch dataset over a DataFrame with 'src' (code text) and 'label' (int) columns."""

    def __init__(self, dataframe, tokenizer, max_length=512):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        snippet = str(row['src'])
        target = int(row['label'])

        # Pad/truncate to a fixed length so the default collate can stack batches.
        enc = self.tokenizer(
            snippet,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )

        # Drop the leading batch dimension the tokenizer adds: (1, L) -> (L,)
        return {
            'input_ids': enc['input_ids'].squeeze(),
            'attention_mask': enc['attention_mask'].squeeze(),
            'label': torch.tensor(target, dtype=torch.long)
        }
|
| 89 |
+
|
| 90 |
+
# Create datasets
train_dataset = CodeDataset(train_df.reset_index(drop=True), tokenizer)
test_dataset = CodeDataset(test_df.reset_index(drop=True), tokenizer)

print(f"✅ Train dataset: {len(train_dataset)} samples")
print(f"✅ Test dataset: {len(test_dataset)} samples")

# --- CELL ---

from transformers import AutoModelForSequenceClassification
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# codecomplex has 7 complexity classes (should equal len(le.classes_))
model = AutoModelForSequenceClassification.from_pretrained(
    "microsoft/codebert-base",
    num_labels=7
)

model = model.to(device)
print("✅ CodeBERT model loaded!")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# --- CELL ---

from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# Optimizer (decoupled weight decay for regularization)
optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

# Scheduler: linear warmup/decay sized for the first 3-epoch run only
total_steps = len(train_loader) * 3  # 3 epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=total_steps // 10,
    num_training_steps=total_steps
)

print(f"✅ DataLoaders ready!")
print(f"Total training steps: {total_steps}")
print(f"Steps per epoch: {len(train_loader)}")
|
| 138 |
+
|
| 139 |
+
# --- CELL ---
|
| 140 |
+
|
| 141 |
+
from tqdm import tqdm
|
| 142 |
+
|
| 143 |
+
def train_epoch(model, loader, optimizer, scheduler, device):
    """Run one optimization pass over `loader`; return (mean batch loss, accuracy).

    Uses the model's built-in (unweighted) cross-entropy loss via `labels=`.
    """
    model.train()
    loss_sum = 0.0
    hits = 0
    seen = 0

    for batch in tqdm(loader, desc="Training"):
        ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        targets = batch['label'].to(device)

        optimizer.zero_grad()
        out = model(input_ids=ids, attention_mask=mask, labels=targets)

        out.loss.backward()
        # Clip gradients to stabilize fine-tuning.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        loss_sum += out.loss.item()
        hits += (out.logits.argmax(dim=1) == targets).sum().item()
        seen += targets.size(0)

    return loss_sum / len(loader), hits / seen
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def evaluate(model, loader, device):
    """Return accuracy of `model` on `loader`, with gradients disabled."""
    model.eval()
    hits, seen = 0, 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            targets = batch['label'].to(device)
            out = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            )
            hits += (out.logits.argmax(dim=1) == targets).sum().item()
            seen += targets.size(0)

    return hits / seen
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
# Train for 3 epochs, checkpointing whenever test accuracy improves.
best_accuracy = 0

for epoch in range(3):
    print(f"\n🚀 Epoch {epoch+1}/3")
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, device)
    test_acc = evaluate(model, test_loader, device)

    print(f"Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}% | Test Acc: {test_acc*100:.2f}%")

    if test_acc > best_accuracy:
        best_accuracy = test_acc
        torch.save(model.state_dict(), "best_model.pt")
        print(f"✅ Best model saved! Accuracy: {best_accuracy*100:.2f}%")

# --- CELL ---

# Train 2 more epochs
# NOTE(review): the scheduler above was sized for exactly 3 epochs, so by now
# the learning rate has decayed to ~0 — these extra epochs barely move the
# weights. A fresh scheduler would be needed for real extra training.
for epoch in range(2):
    print(f"\n🚀 Epoch {epoch+4}/5")
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, device)
    test_acc = evaluate(model, test_loader, device)

    print(f"Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}% | Test Acc: {test_acc*100:.2f}%")

    if test_acc > best_accuracy:
        best_accuracy = test_acc
        torch.save(model.state_dict(), "best_model.pt")
        print(f"✅ Best model saved! Accuracy: {best_accuracy*100:.2f}%")

# --- CELL ---

from google.colab import drive
drive.mount('/content/drive')

# --- CELL ---

import shutil

# Copy files to Google Drive
shutil.copy("best_model.pt", "/content/drive/MyDrive/best_model.pt")
shutil.copy("label_encoder.pkl", "/content/drive/MyDrive/label_encoder.pkl")

print("✅ Files saved to Google Drive!")
|
| 235 |
+
|
| 236 |
+
# --- CELL ---
|
| 237 |
+
|
| 238 |
+
# Test the model directly in Colab on a few hand-written snippets.
test_codes = [
    "public int findMax(int[] arr) { int max = arr[0]; for (int i = 1; i < arr.length; i++) { if (arr[i] > max) max = arr[i]; } return max; }",
    "return arr[0];",
    "for(int i=0;i<n;i++) for(int j=0;j<n;j++) sum+=arr[i][j];",
]

for code in test_codes:
    inputs = tokenizer(code, truncation=True, max_length=512, padding='max_length', return_tensors='pt')
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        pred = torch.argmax(outputs.logits, dim=1).item()

    print(f"Code: {code[:50]}...")
    # Map the predicted class id back to its complexity label string.
    print(f"Predicted: {le.inverse_transform([pred])[0]}\n")
|
| 256 |
+
|
| 257 |
+
# --- CELL ---
|
| 258 |
+
|
| 259 |
+
import torch.nn as nn

# Count class frequencies and build inverse-frequency loss weights.
# NOTE(review): weights are derived from the FULL df rather than train_df;
# with a stratified split the distributions match closely, but train_df
# would be the strictly correct source.
class_counts = df['label'].value_counts().sort_index().values
total = sum(class_counts)
class_weights = torch.tensor([total/c for c in class_counts], dtype=torch.float).to(device)

print("Class weights:", class_weights)
|
| 267 |
+
|
| 268 |
+
# New training loop with weighted loss
|
| 269 |
+
def train_epoch_weighted(model, loader, optimizer, scheduler, device, weights):
    """One training pass using class-weighted cross-entropy; return (mean loss, accuracy)."""
    model.train()
    criterion = nn.CrossEntropyLoss(weight=weights)
    loss_sum, hits, seen = 0.0, 0, 0

    for batch in tqdm(loader, desc="Training"):
        ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        targets = batch['label'].to(device)

        optimizer.zero_grad()
        logits = model(input_ids=ids, attention_mask=mask).logits
        loss = criterion(logits, targets)

        loss.backward()
        # Clip gradients to stabilize fine-tuning.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        loss_sum += loss.item()
        hits += (logits.argmax(dim=1) == targets).sum().item()
        seen += targets.size(0)

    return loss_sum / len(loader), hits / seen
|
| 296 |
+
|
| 297 |
+
# Retrain with weights: fresh optimizer at a much lower LR for gentle fine-tuning,
# with its own warmup/decay schedule over 3 epochs.
optimizer3 = AdamW(model.parameters(), lr=5e-6)
scheduler3 = get_linear_schedule_with_warmup(optimizer3, num_warmup_steps=30, num_training_steps=len(train_loader)*3)

for epoch in range(3):
    print(f"\n🚀 Epoch {epoch+1}/3")
    train_loss, train_acc = train_epoch_weighted(model, train_loader, optimizer3, scheduler3, device, class_weights)
    test_acc = evaluate(model, test_loader, device)
    print(f"Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}% | Test Acc: {test_acc*100:.2f}%")
    if test_acc > best_accuracy:
        best_accuracy = test_acc
        torch.save(model.state_dict(), "best_model.pt")
        print(f"✅ Best model saved! Accuracy: {best_accuracy*100:.2f}%")

# --- CELL ---

import shutil
shutil.copy("best_model.pt", "/content/drive/MyDrive/best_model.pt")
print("✅ Saved to Google Drive!")
|