Spaces:
Running
Running
Update app/main.py
Browse files- app/main.py +44 -36
app/main.py
CHANGED
|
@@ -8,7 +8,7 @@ from typing import List, Optional, Dict
|
|
| 8 |
from concurrent.futures import ThreadPoolExecutor
|
| 9 |
from dotenv import load_dotenv
|
| 10 |
|
| 11 |
-
from fastapi import FastAPI, HTTPException
|
| 12 |
from pydantic import BaseModel
|
| 13 |
import uvicorn
|
| 14 |
|
|
@@ -16,7 +16,6 @@ import uvicorn
|
|
| 16 |
load_dotenv()
|
| 17 |
|
| 18 |
# Import internal services
|
| 19 |
-
# Ensure these modules exist in your /app directory
|
| 20 |
from app.predictor import classifier, guide_generator, reviewer
|
| 21 |
from app.core.model_loader import llm_engine
|
| 22 |
|
|
@@ -28,7 +27,6 @@ logger = logging.getLogger(__name__)
|
|
| 28 |
app = FastAPI(title="GitGud AI Service")
|
| 29 |
|
| 30 |
# Global embedding cache
|
| 31 |
-
# Structure: { "repo_name": { "file_path": [embedding_vector] } }
|
| 32 |
REPO_CACHE: Dict[str, Dict[str, List[float]]] = {}
|
| 33 |
|
| 34 |
# 3. Data Models
|
|
@@ -56,23 +54,31 @@ class ChatRequest(BaseModel):
|
|
| 56 |
|
| 57 |
# 4. Core Logic Helpers
|
| 58 |
def calculate_repo_health(total_vulns: int, avg_maint: float) -> int:
    """
    Standardized health logic to keep Dashboard and Review metrics in sync.

    Combines the average maintainability score (0.0-10.0, scaled to 100)
    with a flat penalty of 8 points per detected vulnerability.
    """
    # Weighted score: maintainability base minus security penalty.
    score = avg_maint * 10 - total_vulns * 8
    # Clamp the reported health into the 10..100 band.
    return int(min(100, max(10, score)))
|
| 70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
# 5. Endpoints
|
| 72 |
|
| 73 |
@app.get("/")
|
| 74 |
def health_check():
|
| 75 |
-
"""Checks server status, GPU availability, and cached data."""
|
| 76 |
return {
|
| 77 |
"status": "online",
|
| 78 |
"model": "microsoft/codebert-base",
|
|
@@ -82,15 +88,12 @@ def health_check():
|
|
| 82 |
|
| 83 |
@app.get("/usage")
def get_usage():
    """Returns AI Service usage statistics from the LLM engine."""
    stats = llm_engine.get_usage_stats()
    return stats
|
| 87 |
|
| 88 |
@app.post("/classify")
|
| 89 |
async def classify_file(request: FileRequest):
|
| 90 |
-
"""Classifies file into architectural layers and caches embeddings."""
|
| 91 |
try:
|
| 92 |
result = classifier.predict(request.fileName, request.content)
|
| 93 |
-
|
| 94 |
if request.repoName:
|
| 95 |
if request.repoName not in REPO_CACHE:
|
| 96 |
REPO_CACHE[request.repoName] = {}
|
|
@@ -108,7 +111,6 @@ async def classify_file(request: FileRequest):
|
|
| 108 |
|
| 109 |
@app.post("/review-batch-code")
|
| 110 |
async def review_batch_code(request: BatchReviewRequest):
|
| 111 |
-
"""Detailed review results for the Review Section UI."""
|
| 112 |
try:
|
| 113 |
loop = asyncio.get_event_loop()
|
| 114 |
with ThreadPoolExecutor() as executor:
|
|
@@ -117,14 +119,14 @@ async def review_batch_code(request: BatchReviewRequest):
|
|
| 117 |
)
|
| 118 |
return {"results": reviews}
|
| 119 |
except Exception as e:
|
|
|
|
|
|
|
| 120 |
traceback.print_exc()
|
| 121 |
raise HTTPException(status_code=500, detail=str(e))
|
| 122 |
|
| 123 |
@app.post("/repo-dashboard-stats")
|
| 124 |
async def get_dashboard_stats(request: BatchReviewRequest):
|
| 125 |
-
"""Aggregated stats using the exact same logic as batch review."""
|
| 126 |
try:
|
| 127 |
-
# Run heavy AI review in thread pool to keep FastAPI responsive
|
| 128 |
loop = asyncio.get_event_loop()
|
| 129 |
with ThreadPoolExecutor() as executor:
|
| 130 |
raw_reviews = await loop.run_in_executor(
|
|
@@ -137,15 +139,11 @@ async def get_dashboard_stats(request: BatchReviewRequest):
|
|
| 137 |
api_regex = re.compile(r'(?:get|post|put|delete|patch)\([\'"]\/(.*?)[\'"]', re.IGNORECASE)
|
| 138 |
|
| 139 |
for i, review in enumerate(raw_reviews):
|
| 140 |
-
# Sync Vulnerability count
|
| 141 |
vulns = review.get("vulnerabilities", [])
|
| 142 |
total_vulns += len(vulns)
|
| 143 |
-
|
| 144 |
-
# Sync Maintainability
|
| 145 |
m_score = review.get("metrics", {}).get("maintainability", 8.0)
|
| 146 |
maint_scores.append(m_score)
|
| 147 |
|
| 148 |
-
# Extract APIs (No cap, show all discovered)
|
| 149 |
content = request.files[i].content
|
| 150 |
if content:
|
| 151 |
matches = api_regex.findall(content)
|
|
@@ -154,8 +152,6 @@ async def get_dashboard_stats(request: BatchReviewRequest):
|
|
| 154 |
|
| 155 |
num_files = len(maint_scores)
|
| 156 |
avg_maint = (sum(maint_scores) / num_files) if num_files > 0 else 0
|
| 157 |
-
|
| 158 |
-
# Calculate health using shared logic
|
| 159 |
health_score = calculate_repo_health(total_vulns, avg_maint)
|
| 160 |
|
| 161 |
return {
|
|
@@ -168,13 +164,13 @@ async def get_dashboard_stats(request: BatchReviewRequest):
|
|
| 168 |
"average_maintainability": round(avg_maint, 1)
|
| 169 |
}
|
| 170 |
except Exception as e:
|
|
|
|
|
|
|
| 171 |
logger.error(f"Dashboard stats failed: {e}")
|
| 172 |
-
traceback.print_exc()
|
| 173 |
raise HTTPException(status_code=500, detail="Failed to sync dashboard metrics")
|
| 174 |
|
| 175 |
@app.post("/analyze-file")
|
| 176 |
async def analyze_file(request: FileRequest):
|
| 177 |
-
"""Deep analysis: Summary, Tags, and Layer Classification."""
|
| 178 |
try:
|
| 179 |
result = classifier.predict(request.fileName, request.content)
|
| 180 |
summary = classifier.generate_file_summary(request.content, request.fileName)
|
|
@@ -188,12 +184,12 @@ async def analyze_file(request: FileRequest):
|
|
| 188 |
"embedding": result["embedding"],
|
| 189 |
}
|
| 190 |
except Exception as e:
|
| 191 |
-
|
|
|
|
| 192 |
raise HTTPException(status_code=500, detail=str(e))
|
| 193 |
|
| 194 |
@app.post("/semantic-search")
|
| 195 |
async def semantic_search(request: SearchRequest):
|
| 196 |
-
"""Search code using natural language and vector similarity."""
|
| 197 |
try:
|
| 198 |
embeddings = request.embeddings
|
| 199 |
if not embeddings and request.repoName and request.repoName in REPO_CACHE:
|
|
@@ -205,12 +201,10 @@ async def semantic_search(request: SearchRequest):
|
|
| 205 |
results = classifier.semantic_search(request.query, embeddings)
|
| 206 |
return {"results": results}
|
| 207 |
except Exception as e:
|
| 208 |
-
traceback.print_exc()
|
| 209 |
raise HTTPException(status_code=500, detail=str(e))
|
| 210 |
|
| 211 |
@app.post("/chat")
|
| 212 |
async def chat(request: ChatRequest):
|
| 213 |
-
"""RAG-based chat using provided file context."""
|
| 214 |
start_time = time.time()
|
| 215 |
try:
|
| 216 |
context_str = ""
|
|
@@ -224,21 +218,35 @@ async def chat(request: ChatRequest):
|
|
| 224 |
USER QUESTION: {request.query}
|
| 225 |
"""
|
| 226 |
response = llm_engine.generate_text(prompt)
|
| 227 |
-
|
| 228 |
-
logger.info(f"Chat generated in {time.time() - start_time:.2f}s")
|
| 229 |
-
return {"response": response}
|
| 230 |
except Exception as e:
|
| 231 |
-
|
|
|
|
|
|
|
| 232 |
raise HTTPException(status_code=500, detail=str(e))
|
| 233 |
|
| 234 |
@app.post("/generate-guide")
|
| 235 |
async def generate_guide(request: GuideRequest):
|
| 236 |
-
"""
|
|
|
|
|
|
|
| 237 |
try:
|
| 238 |
markdown = guide_generator.generate_markdown(request.repoName, request.filePaths)
|
| 239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
except Exception as e:
|
| 241 |
-
|
|
|
|
| 242 |
raise HTTPException(status_code=500, detail=str(e))
|
| 243 |
|
| 244 |
# 6. Application Entry Point
|
|
|
|
| 8 |
from concurrent.futures import ThreadPoolExecutor
|
| 9 |
from dotenv import load_dotenv
|
| 10 |
|
| 11 |
+
from fastapi import FastAPI, HTTPException, status
|
| 12 |
from pydantic import BaseModel
|
| 13 |
import uvicorn
|
| 14 |
|
|
|
|
| 16 |
load_dotenv()
|
| 17 |
|
| 18 |
# Import internal services
|
|
|
|
| 19 |
from app.predictor import classifier, guide_generator, reviewer
|
| 20 |
from app.core.model_loader import llm_engine
|
| 21 |
|
|
|
|
| 27 |
app = FastAPI(title="GitGud AI Service")
|
| 28 |
|
| 29 |
# Global embedding cache
|
|
|
|
| 30 |
REPO_CACHE: Dict[str, Dict[str, List[float]]] = {}
|
| 31 |
|
| 32 |
# 3. Data Models
|
|
|
|
| 54 |
|
| 55 |
# 4. Core Logic Helpers
|
| 56 |
def calculate_repo_health(total_vulns: int, avg_maint: float) -> int:
    """Return a repo health score in 10..100 from maintainability and vuln count."""
    # avg_maint is on a 0.0-10.0 scale; each vulnerability costs 8 points.
    raw = (avg_maint * 10) - (total_vulns * 8)
    if raw > 100:
        raw = 100
    elif raw < 10:
        raw = 10
    return int(raw)
|
| 61 |
|
| 62 |
+
def parse_tree_to_list(raw_tree: str):
    """Parses text tree into JSON objects for Compose LazyColumn"""
    tree_chars = re.compile(r'[|└├─]')
    nodes = []
    for raw_line in raw_tree.strip().split('\n'):
        # Depth heuristic: every pipe plus each pair of spaces adds a level.
        depth = raw_line.count('|') + raw_line.count(' ') // 2
        # Strip the tree-drawing characters to recover the entry name.
        label = tree_chars.sub('', raw_line).strip()
        if not label:
            continue
        nodes.append({
            "name": label,
            "type": "file" if '.' in label else "folder",
            "level": depth,
        })
    return nodes
|
| 77 |
+
|
| 78 |
# 5. Endpoints
|
| 79 |
|
| 80 |
@app.get("/")
|
| 81 |
def health_check():
|
|
|
|
| 82 |
return {
|
| 83 |
"status": "online",
|
| 84 |
"model": "microsoft/codebert-base",
|
|
|
|
| 88 |
|
| 89 |
@app.get("/usage")
|
| 90 |
def get_usage():
|
|
|
|
| 91 |
return llm_engine.get_usage_stats()
|
| 92 |
|
| 93 |
@app.post("/classify")
|
| 94 |
async def classify_file(request: FileRequest):
|
|
|
|
| 95 |
try:
|
| 96 |
result = classifier.predict(request.fileName, request.content)
|
|
|
|
| 97 |
if request.repoName:
|
| 98 |
if request.repoName not in REPO_CACHE:
|
| 99 |
REPO_CACHE[request.repoName] = {}
|
|
|
|
| 111 |
|
| 112 |
@app.post("/review-batch-code")
|
| 113 |
async def review_batch_code(request: BatchReviewRequest):
|
|
|
|
| 114 |
try:
|
| 115 |
loop = asyncio.get_event_loop()
|
| 116 |
with ThreadPoolExecutor() as executor:
|
|
|
|
| 119 |
)
|
| 120 |
return {"results": reviews}
|
| 121 |
except Exception as e:
|
| 122 |
+
if "429" in str(e):
|
| 123 |
+
raise HTTPException(status_code=429, detail="AI Quota Exceeded")
|
| 124 |
traceback.print_exc()
|
| 125 |
raise HTTPException(status_code=500, detail=str(e))
|
| 126 |
|
| 127 |
@app.post("/repo-dashboard-stats")
|
| 128 |
async def get_dashboard_stats(request: BatchReviewRequest):
|
|
|
|
| 129 |
try:
|
|
|
|
| 130 |
loop = asyncio.get_event_loop()
|
| 131 |
with ThreadPoolExecutor() as executor:
|
| 132 |
raw_reviews = await loop.run_in_executor(
|
|
|
|
| 139 |
api_regex = re.compile(r'(?:get|post|put|delete|patch)\([\'"]\/(.*?)[\'"]', re.IGNORECASE)
|
| 140 |
|
| 141 |
for i, review in enumerate(raw_reviews):
|
|
|
|
| 142 |
vulns = review.get("vulnerabilities", [])
|
| 143 |
total_vulns += len(vulns)
|
|
|
|
|
|
|
| 144 |
m_score = review.get("metrics", {}).get("maintainability", 8.0)
|
| 145 |
maint_scores.append(m_score)
|
| 146 |
|
|
|
|
| 147 |
content = request.files[i].content
|
| 148 |
if content:
|
| 149 |
matches = api_regex.findall(content)
|
|
|
|
| 152 |
|
| 153 |
num_files = len(maint_scores)
|
| 154 |
avg_maint = (sum(maint_scores) / num_files) if num_files > 0 else 0
|
|
|
|
|
|
|
| 155 |
health_score = calculate_repo_health(total_vulns, avg_maint)
|
| 156 |
|
| 157 |
return {
|
|
|
|
| 164 |
"average_maintainability": round(avg_maint, 1)
|
| 165 |
}
|
| 166 |
except Exception as e:
|
| 167 |
+
if "429" in str(e):
|
| 168 |
+
raise HTTPException(status_code=429, detail="Quota exceeded")
|
| 169 |
logger.error(f"Dashboard stats failed: {e}")
|
|
|
|
| 170 |
raise HTTPException(status_code=500, detail="Failed to sync dashboard metrics")
|
| 171 |
|
| 172 |
@app.post("/analyze-file")
|
| 173 |
async def analyze_file(request: FileRequest):
|
|
|
|
| 174 |
try:
|
| 175 |
result = classifier.predict(request.fileName, request.content)
|
| 176 |
summary = classifier.generate_file_summary(request.content, request.fileName)
|
|
|
|
| 184 |
"embedding": result["embedding"],
|
| 185 |
}
|
| 186 |
except Exception as e:
|
| 187 |
+
if "429" in str(e):
|
| 188 |
+
raise HTTPException(status_code=429, detail="Limit Reached")
|
| 189 |
raise HTTPException(status_code=500, detail=str(e))
|
| 190 |
|
| 191 |
@app.post("/semantic-search")
|
| 192 |
async def semantic_search(request: SearchRequest):
|
|
|
|
| 193 |
try:
|
| 194 |
embeddings = request.embeddings
|
| 195 |
if not embeddings and request.repoName and request.repoName in REPO_CACHE:
|
|
|
|
| 201 |
results = classifier.semantic_search(request.query, embeddings)
|
| 202 |
return {"results": results}
|
| 203 |
except Exception as e:
|
|
|
|
| 204 |
raise HTTPException(status_code=500, detail=str(e))
|
| 205 |
|
| 206 |
@app.post("/chat")
|
| 207 |
async def chat(request: ChatRequest):
|
|
|
|
| 208 |
start_time = time.time()
|
| 209 |
try:
|
| 210 |
context_str = ""
|
|
|
|
| 218 |
USER QUESTION: {request.query}
|
| 219 |
"""
|
| 220 |
response = llm_engine.generate_text(prompt)
|
| 221 |
+
return {"response": response, "status": "success"}
|
|
|
|
|
|
|
| 222 |
except Exception as e:
|
| 223 |
+
if "429" in str(e):
|
| 224 |
+
# Return a structured error so Compose can show a nice UI
|
| 225 |
+
return {"response": "⚠️ Daily limit reached. Try again in a bit!", "status": "quota_error"}
|
| 226 |
raise HTTPException(status_code=500, detail=str(e))
|
| 227 |
|
| 228 |
@app.post("/generate-guide")
async def generate_guide(request: GuideRequest):
    """
    Generate a markdown guide for the repo and also return the project
    tree parsed into structured JSON nodes for the Compose UI.
    """
    try:
        markdown = guide_generator.generate_markdown(request.repoName, request.filePaths)

        # Pull the raw tree section out of the markdown and convert it.
        match = re.search(r"Project Structure\n\n(.*?)(?=\n\n|$)", markdown, re.S)
        structured_tree = parse_tree_to_list(match.group(1)) if match else []

        return {
            "markdown": markdown,
            "structured_tree": structured_tree,
            "project_name": request.repoName,
        }
    except Exception as e:
        # Surface provider rate limiting as 429 instead of a generic 500.
        if "429" in str(e):
            raise HTTPException(status_code=429, detail="AI Quota Exceeded")
        raise HTTPException(status_code=500, detail=str(e))
|
| 251 |
|
| 252 |
# 6. Application Entry Point
|