Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import torch.nn.functional as F | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from transformers import AutoTokenizer | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from huggingface_hub import hf_hub_download | |
| from safetensors.torch import load_file | |
| from model import ComplexityFusionModel | |
| from features import clean_code, get_python_features, get_java_features | |
| from explainer import generate_shap_explanation | |
# API setup: the application object that the route handlers below attach to.
app = FastAPI(title="Code Complexity XAI API", version="1.0.0")

# CORS is left fully open (any origin, any method, any header) so a browser
# front-end served from a different origin can call this API directly.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Maps the classifier's output index to the complexity class name returned to clients.
label_map = {0: 'CONSTANT', 1: 'LINEAR', 2: 'LOGN', 3: 'NLOGN', 4: 'QUADRATIC', 5: 'CUBIC', 6: 'NP'}

# Hugging Face Hub repository that hosts both the tokenizer and the weights file.
REPO_ID = "himansha2001/algox"

print("Booting up backend services...")
tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
model = ComplexityFusionModel(model_name="microsoft/unixcoder-base", num_labels=7, num_static_features=5)

# Download (or reuse the cached copy of) the checkpoint and read its tensors.
safetensors_path = hf_hub_download(repo_id=REPO_ID, filename="model.safetensors")
state_dict = load_file(safetensors_path)

# strict=False tolerates key mismatches between the checkpoint and the model.
# Report any mismatch instead of ignoring it, so a checkpoint that only
# partially matches the architecture is visible in the boot log.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(f"WARNING: model parameters missing from checkpoint: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"WARNING: checkpoint keys not used by the model: {load_result.unexpected_keys}")

model.to(device)
model.eval()  # inference mode: disables dropout and similar training-only behavior
print("API is ready for inference!")
# Request body accepted by the prediction endpoint.
# (Documented with comments rather than a docstring so the generated
# OpenAPI schema for this model stays exactly as it was.)
class CodeRequest(BaseModel):
    # Source code to analyse, sent as plain text.
    code: str
    # Language of `code`; the prediction handler lower-cases it and accepts
    # 'python' or 'java'.
    language: str
# Register the handler on the application root; without a route decorator
# the function is defined but never reachable over HTTP.
@app.get("/")
async def health_check():
    """
    Root endpoint to verify the API is online and the model is loaded.

    Returns:
        A dict with the fixed keys ``status``, ``message``, ``model_loaded``
        and ``version``.
    """
    # The model is loaded at import time, before any request can be served,
    # so reaching this handler implies that loading completed.
    return {
        "status": "online",
        "message": "Code Complexity XAI API is running successfully.",
        "model_loaded": True,
        "version": "1.0.0"
    }
# NOTE(review): no route decorator was present for this handler, which left it
# unreachable over HTTP. The "/predict" path is an assumption -- confirm it
# against the client that calls this API.
@app.post("/predict")
async def predict_complexity(request: CodeRequest):
    """
    Endpoint to predict the complexity of the provided code and generate an explanation.

    Args:
        request: The submitted source code and its language. The language is
            matched case-insensitively.

    Returns:
        A dict with the predicted ``complexity`` label, its softmax
        ``confidence``, the extracted ``static_features`` and the
        ``shap_explanation`` produced by ``generate_shap_explanation``.

    Raises:
        HTTPException: 400 when the language is neither 'python' nor 'java'.
    """
    # NOTE(review): tokenization, the forward pass and the SHAP step all run
    # synchronously inside this coroutine, so the event loop is blocked for
    # the whole request. Consider a plain `def` handler or a worker thread.
    lang = request.language.lower()

    # Validate the language before doing any work, so an unsupported value is
    # rejected with a 400 instead of first being handed to clean_code().
    feature_extractors = {
        'python': get_python_features,
        'java': get_java_features,
    }
    extract_features = feature_extractors.get(lang)
    if extract_features is None:
        raise HTTPException(status_code=400, detail="Language must be 'java' or 'python'")

    # Prepare Data
    cleaned_code = clean_code(request.code, lang)
    # Static features are computed from the raw submission, not the cleaned text.
    feats = extract_features(request.code)
    # Wrap in a list to add a leading batch dimension of size one.
    request_static_features = torch.tensor([feats], dtype=torch.float32).to(device)

    # Tokenize & Predict
    inputs = tokenizer(cleaned_code, return_tensors="pt", truncation=True, max_length=512).to(device)
    with torch.no_grad():
        logits = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            static_features=request_static_features
        )
        probs = F.softmax(logits, dim=1)
        pred_idx = probs.argmax().item()
        confidence = probs.max().item()

    prediction = label_map[pred_idx]

    # Generate SHAP Explanation (kept outside no_grad so the explainer is not
    # restricted by the inference-only context).
    shap_explanation = generate_shap_explanation(
        cleaned_code=cleaned_code,
        model=model,
        tokenizer=tokenizer,
        static_features_tensor=request_static_features,
        device=device,
        pred_idx=pred_idx,
        label_map=label_map
    )

    # Return Response
    return {
        "complexity": prediction,
        "confidence": float(confidence),
        "static_features": {
            "max_depth": feats[0],
            "branch_count": feats[1],
            "has_recursion": bool(feats[2]),
            "has_log_math": bool(feats[3]),
            "has_sort": bool(feats[4])
        },
        "shap_explanation": shap_explanation
    }