# gitgud-ai — app/main.py
# (source recovered from a web paste; page chrome removed)
import os
import re
import logging
import traceback
import time
import asyncio
from typing import List, Optional, Dict
from concurrent.futures import ThreadPoolExecutor
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel
import uvicorn
# Load environment variables before the service imports below read them.
load_dotenv()
# Import internal services
# NOTE(review): these look like module-level singletons — confirm whether
# importing them loads models eagerly (affects startup time).
from app.predictor import classifier, guide_generator, reviewer
from app.core.model_loader import llm_engine
# 1. Setup Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 2. Initialize FastAPI
app = FastAPI(title="GitGud AI Service")
# Global embedding cache: repoName -> {fileName -> embedding vector}.
# Process-local and unbounded; contents are lost on restart.
REPO_CACHE: Dict[str, Dict[str, List[float]]] = {}
# 3. Data Models
class FileRequest(BaseModel):
    """Payload describing a single source file to analyze."""

    # Path/name of the file within the repository.
    fileName: str
    # Raw file contents; optional because some endpoints only need the name.
    content: Optional[str] = None
    # Repository identifier, used to key the embedding cache.
    repoName: Optional[str] = None
class BatchReviewRequest(BaseModel):
    """Request body for the batch code-review and dashboard endpoints."""

    # Files to review in a single batch.
    files: List[FileRequest]
class GuideRequest(BaseModel):
    """Request body for onboarding-guide generation."""

    # Repository display name echoed back in the response.
    repoName: str
    # File paths used to build the project-structure section.
    filePaths: List[str]
class SearchRequest(BaseModel):
    """Request body for semantic search over file embeddings."""

    # Free-text search query.
    query: str
    # Client-supplied embeddings; if absent the server-side cache is used.
    embeddings: Optional[Dict[str, List[float]]] = None
    # Repository key for the server-side embedding cache fallback.
    repoName: Optional[str] = None
class ChatRequest(BaseModel):
    """Request body for the repo-aware chat endpoint."""

    # The user's question.
    query: str
    # Context files as dicts with 'fileName' and 'content' keys.
    context: List[Dict[str, str]]
    # Repository name injected into the LLM prompt.
    repoName: str
# 4. Core Logic Helpers
def calculate_repo_health(total_vulns: int, avg_maint: float) -> int:
    """Score repository health on a 10-100 scale.

    The average maintainability (a 0-10 metric) is scaled to a 0-100 base,
    each vulnerability subtracts 8 points, and the result is clamped so even
    a terrible repo scores 10 and a perfect one caps at 100.
    """
    raw_score = avg_maint * 10 - total_vulns * 8
    return int(max(10, min(100, raw_score)))
def parse_tree_to_list(raw_tree: str):
    """Convert an ASCII/Unicode tree listing into flat node dicts.

    Each node carries a display name, a file/folder guess (based on the
    presence of a '.'), and an indentation level for the Compose LazyColumn.
    """
    # Compile once; the same pattern strips drawing characters on every line.
    drawing_chars = re.compile(r'[|└├─]')
    nodes = []
    for raw_line in raw_tree.strip().split('\n'):
        # Depth heuristic: each pipe and each pair of spaces counts one level.
        depth = raw_line.count('|') + raw_line.count(' ') // 2
        label = drawing_chars.sub('', raw_line).strip()
        if not label:
            continue
        nodes.append({
            "name": label,
            "type": "file" if '.' in label else "folder",
            "level": depth,
        })
    return nodes
# 5. Endpoints
@app.get("/")
def health_check():
    """Liveness probe: reports model info and which repos are cached."""
    device = getattr(classifier, "device", "cpu")
    return {
        "status": "online",
        "model": "microsoft/codebert-base",
        "device": device,
        "cached_repos": list(REPO_CACHE.keys()),
    }
@app.get("/usage")
def get_usage():
    """Expose the LLM engine's usage counters for monitoring."""
    stats = llm_engine.get_usage_stats()
    return stats
@app.post("/classify")
async def classify_file(request: FileRequest):
    """Classify one file into an architecture layer and cache its embedding.

    Returns the predicted layer label, confidence, and the raw embedding.
    Raises HTTP 429 when the model quota is exhausted, 500 otherwise.
    """
    try:
        result = classifier.predict(request.fileName, request.content)
        # Cache the embedding per repo so /semantic-search can reuse it later.
        if request.repoName:
            REPO_CACHE.setdefault(request.repoName, {})[request.fileName] = result["embedding"]
        return {
            "fileName": request.fileName,
            "layer": result["label"],
            "confidence": result["confidence"],
            "embedding": result["embedding"]
        }
    except Exception as e:
        # Quota handling mirrors the other endpoints in this file, which the
        # original classify handler was missing.
        if "429" in str(e):
            raise HTTPException(status_code=429, detail="AI Quota Exceeded")
        logger.error(f"Classify failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/review-batch-code")
async def review_batch_code(request: BatchReviewRequest):
    """Review many files concurrently, one file per worker thread.

    Individual file failures are logged and skipped; a quota (429) error
    anywhere aborts the whole batch with HTTP 429.
    """
    try:
        # get_running_loop() is the supported call inside a coroutine;
        # get_event_loop() here is deprecated since Python 3.10.
        loop = asyncio.get_running_loop()

        def process_single_file(file_req):
            # The service API accepts a batch; wrap each file in a singleton
            # list so one thread handles exactly one file.
            return reviewer.service.review_batch_code([file_req])

        # Up to 15 files are reviewed concurrently.
        with ThreadPoolExecutor(max_workers=15) as executor:
            tasks = [
                loop.run_in_executor(executor, process_single_file, f)
                for f in request.files
            ]
            # return_exceptions=True keeps one bad file from cancelling the rest.
            raw_reviews = await asyncio.gather(*tasks, return_exceptions=True)

        valid_reviews = []
        for i, result in enumerate(raw_reviews):
            if isinstance(result, Exception):
                # A rate-limit error anywhere aborts the whole batch.
                if "429" in str(result):
                    raise HTTPException(status_code=429, detail="AI Quota Exceeded")
                # Otherwise log the specific file error but keep the batch alive.
                logger.error(f"Failed to analyze {request.files[i].fileName}: {result}")
            else:
                valid_reviews.append(result)
        return {"results": valid_reviews}
    except HTTPException:
        # Re-raise our own 429 untouched instead of wrapping it in a 500.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/repo-dashboard-stats")
async def get_dashboard_stats(request: BatchReviewRequest):
    """Aggregate batch review results into repo-level dashboard metrics.

    Computes a health score, total vulnerability count, average
    maintainability, and a best-effort list of API routes scraped from the
    submitted file contents.
    """
    try:
        # Run the blocking batch review off the event loop thread.
        # get_running_loop() replaces the deprecated get_event_loop().
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor() as executor:
            raw_reviews = await loop.run_in_executor(
                executor, reviewer.service.review_batch_code, request.files
            )

        total_vulns = 0
        maint_scores = []
        found_apis = set()
        # Matches route registrations like app.get("/users") or post('/login').
        api_regex = re.compile(r'(?:get|post|put|delete|patch)\([\'"]\/(.*?)[\'"]', re.IGNORECASE)

        for i, review in enumerate(raw_reviews):
            vulns = review.get("vulnerabilities", [])
            total_vulns += len(vulns)
            # Default maintainability when the reviewer omitted metrics.
            m_score = review.get("metrics", {}).get("maintainability", 8.0)
            maint_scores.append(m_score)
            content = request.files[i].content
            if content:
                for match in api_regex.findall(content):
                    found_apis.add(f"/{match}")

        num_files = len(maint_scores)
        avg_maint = (sum(maint_scores) / num_files) if num_files > 0 else 0
        health_score = calculate_repo_health(total_vulns, avg_maint)
        return {
            "repo_health": health_score,
            "health_label": "Excellent" if health_score > 85 else "Good" if health_score > 60 else "Critical",
            "security_issues": total_vulns,
            "performance_ratio": f"{int(avg_maint * 10)}%",
            "exposed_apis": list(found_apis),
            "total_files_processed": num_files,
            "average_maintainability": round(avg_maint, 1)
        }
    except Exception as e:
        if "429" in str(e):
            raise HTTPException(status_code=429, detail="Quota exceeded")
        logger.error(f"Dashboard stats failed: {e}")
        raise HTTPException(status_code=500, detail="Failed to sync dashboard metrics")
@app.post("/analyze-file")
async def analyze_file(request: FileRequest):
    """Return layer, summary, tags, and embedding for a single file."""
    try:
        prediction = classifier.predict(request.fileName, request.content)
        file_summary = classifier.generate_file_summary(request.content, request.fileName)
        file_tags = classifier.extract_tags(request.content, request.fileName)
        response = {
            "fileName": request.fileName,
            "layer": prediction["label"],
            "summary": file_summary,
            "tags": file_tags,
            "embedding": prediction["embedding"],
        }
        return response
    except Exception as e:
        if "429" in str(e):
            raise HTTPException(status_code=429, detail="Limit Reached")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/semantic-search")
async def semantic_search(request: SearchRequest):
    """Search file embeddings (client-supplied or cached) for the query."""
    try:
        vectors = request.embeddings
        # Fall back to the server-side cache when the client sent no vectors.
        if not vectors and request.repoName:
            vectors = REPO_CACHE.get(request.repoName)
        if not vectors:
            return {"results": []}
        return {"results": classifier.semantic_search(request.query, vectors)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/chat")
async def chat(request: ChatRequest):
    """Answer a repository question via the LLM, with optional file context.

    On quota exhaustion returns a structured quota_error payload (HTTP 200)
    instead of an error status, so the Compose client can render it nicely.
    """
    try:
        # join() instead of the original += loop; also drops the unused
        # start_time local the original assigned but never read.
        context_str = "".join(
            f"--- FILE: {item['fileName']} ---\n{item['content']}\n\n"
            for item in request.context
        )
        prompt = f"""
You are "GitGud AI", an expert software architect.
Repository: "{request.repoName}"
CONTEXT: {context_str if request.context else "(NO CODE PROVIDED)"}
USER QUESTION: {request.query}
"""
        response = llm_engine.generate_text(prompt)
        return {"response": response, "status": "success"}
    except Exception as e:
        if "429" in str(e):
            # Structured error so Compose can show a nice UI
            return {"response": "⚠️ Daily limit reached. Try again in a bit!", "status": "quota_error"}
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/generate-guide")
async def generate_guide(request: GuideRequest):
    """Generate an onboarding guide plus a structured file tree for the UI.

    The raw tree text under the "Project Structure" heading is parsed into
    JSON nodes so the Compose client does not have to parse markdown.
    """
    try:
        markdown = guide_generator.generate_markdown(request.repoName, request.filePaths)
        # Pull the raw tree text out of the "Project Structure" section.
        tree_match = re.search(r"Project Structure\n\n(.*?)(?=\n\n|$)", markdown, re.S)
        structured_tree = parse_tree_to_list(tree_match.group(1)) if tree_match else []
        return {
            "markdown": markdown,
            "structured_tree": structured_tree,
            "project_name": request.repoName,
        }
    except Exception as e:
        if "429" in str(e):
            raise HTTPException(status_code=429, detail="AI Quota Exceeded")
        raise HTTPException(status_code=500, detail=str(e))
# 6. Application Entry Point
if __name__ == "__main__":
    # Hosting platforms inject PORT; 7860 is the local/Spaces default.
    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)