import os
import re
import logging
import traceback
import time
import asyncio
from typing import List, Optional, Dict
from concurrent.futures import ThreadPoolExecutor
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel
import uvicorn

# Load environment variables
load_dotenv()

# Import internal services
from app.predictor import classifier, guide_generator, reviewer
from app.core.model_loader import llm_engine

# 1. Setup Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 2. Initialize FastAPI
app = FastAPI(title="GitGud AI Service")

# Global embedding cache: repoName -> {fileName -> embedding vector}.
# Populated by /classify; consumed by /semantic-search as a fallback.
REPO_CACHE: Dict[str, Dict[str, List[float]]] = {}


# 3. Data Models
class FileRequest(BaseModel):
    fileName: str
    content: Optional[str] = None
    repoName: Optional[str] = None


class BatchReviewRequest(BaseModel):
    files: List[FileRequest]


class GuideRequest(BaseModel):
    repoName: str
    filePaths: List[str]


class SearchRequest(BaseModel):
    query: str
    embeddings: Optional[Dict[str, List[float]]] = None
    repoName: Optional[str] = None


class ChatRequest(BaseModel):
    query: str
    # NOTE(review): items appear to carry 'fileName' and 'content' keys —
    # confirm against the Compose client.
    context: List[Dict[str, str]]
    repoName: str


# 4. Core Logic Helpers
def calculate_repo_health(total_vulns: int, avg_maint: float) -> int:
    """Combine vulnerability count and maintainability into a 10-100 health score.

    Each vulnerability costs 8 points off a base of ``avg_maint * 10``;
    the result is clamped to the [10, 100] range.
    """
    base_score = avg_maint * 10
    penalty = total_vulns * 8
    return int(max(10, min(100, base_score - penalty)))


# ASCII pipe plus the Unicode box-drawing characters tree renderers emit.
_TREE_DECOR_RE = re.compile(r'[|│└├─]')
_TREE_DECOR_CHARS = ' |│└├─'


def parse_tree_to_list(raw_tree: str) -> List[Dict[str, object]]:
    """Parse a text directory tree into flat JSON nodes for a Compose LazyColumn.

    Each node carries its display name, a file/folder guess (based on the
    presence of a '.' in the name), and an indentation ``level``.

    FIX: depth is now derived only from the *leading* decoration of each
    line (tree characters and indentation spaces), not from the whole line —
    previously a space inside a file name inflated its level. The Unicode
    vertical bar '│' is now also stripped and counted alongside ASCII '|'.
    """
    nodes = []
    for line in raw_tree.strip().split('\n'):
        # Split the line into its decoration prefix and the payload text.
        body = line.lstrip(_TREE_DECOR_CHARS)
        prefix = line[:len(line) - len(body)]
        # Clean any stray tree characters out of the name itself.
        name = _TREE_DECOR_RE.sub('', body).strip()
        if not name:
            continue
        # Depth heuristic: each vertical bar or pair of indent spaces = 1 level.
        level = prefix.count('|') + prefix.count('│') + (prefix.count(' ') // 2)
        nodes.append({
            "name": name,
            "type": "file" if '.' in name else "folder",
            "level": level,
        })
    return nodes


# 5.
# Endpoints
@app.get("/")
def health_check():
    """Liveness probe; also reports which repos currently have cached embeddings."""
    return {
        "status": "online",
        "model": "microsoft/codebert-base",
        # classifier may not expose a device attribute in every build.
        "device": getattr(classifier, "device", "cpu"),
        "cached_repos": list(REPO_CACHE.keys()),
    }


@app.get("/usage")
def get_usage():
    """Expose the LLM engine's usage/quota counters."""
    return llm_engine.get_usage_stats()


@app.post("/classify")
async def classify_file(request: FileRequest):
    """Classify a single file into an architectural layer and cache its embedding.

    Returns the predicted layer, confidence, and raw embedding vector.
    Raises HTTP 500 on any classifier failure.
    """
    try:
        result = classifier.predict(request.fileName, request.content)
        if request.repoName:
            # Cache the embedding so /semantic-search can answer without re-embedding.
            REPO_CACHE.setdefault(request.repoName, {})[request.fileName] = result["embedding"]
        return {
            "fileName": request.fileName,
            "layer": result["label"],
            "confidence": result["confidence"],
            "embedding": result["embedding"]
        }
    except Exception as e:
        logger.error(f"Classify failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/review-batch-code")
async def review_batch_code(request: BatchReviewRequest):
    """Review every file concurrently — one file per thread, up to 15 workers.

    Individual file failures are logged and skipped; a 429 from the provider
    aborts the whole batch with HTTP 429.
    """
    try:
        # FIX: get_running_loop() is the supported call inside a coroutine;
        # get_event_loop() is deprecated in this context since Python 3.10.
        loop = asyncio.get_running_loop()

        # Helper function to process one file at a time
        def process_single_file(file_req):
            # We wrap it in a list so your existing service method still accepts it,
            # but it only processes 1 file per thread.
            return reviewer.service.review_batch_code([file_req])

        # Spin up a ThreadPoolExecutor with up to 15 concurrent workers
        with ThreadPoolExecutor(max_workers=15) as executor:
            # Create a concurrent task for every file in the request
            tasks = [
                loop.run_in_executor(executor, process_single_file, f)
                for f in request.files
            ]
            # asyncio.gather fires them all off at the exact same time
            raw_reviews = await asyncio.gather(*tasks, return_exceptions=True)

        # Clean up the results and handle any individual file failures gracefully
        valid_reviews = []
        for i, result in enumerate(raw_reviews):
            if isinstance(result, Exception):
                # If we hit a rate limit, bubble it up immediately
                if "429" in str(result):
                    raise HTTPException(status_code=429, detail="AI Quota Exceeded")
                # Otherwise, log the specific file error but don't crash the whole batch
                logger.error(f"Failed to analyze {request.files[i].fileName}: {result}")
            else:
                valid_reviews.append(result)

        return {"results": valid_reviews}
    except Exception as e:
        if isinstance(e, HTTPException):
            raise e
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/repo-dashboard-stats")
async def get_dashboard_stats(request: BatchReviewRequest):
    """Aggregate a batch review into repo-level dashboard metrics.

    Produces a health score, vulnerability totals, average maintainability,
    and a best-effort list of API routes found in file contents.
    """
    try:
        # FIX: get_running_loop() replaces the deprecated get_event_loop().
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor() as executor:
            # Single blocking batch call, pushed off the event loop.
            raw_reviews = await loop.run_in_executor(
                executor, reviewer.service.review_batch_code, request.files
            )

        total_vulns = 0
        maint_scores = []
        found_apis = set()
        # Matches e.g. app.get("/users") / router.post('/login') style routes.
        api_regex = re.compile(r'(?:get|post|put|delete|patch)\([\'"]\/(.*?)[\'"]', re.IGNORECASE)

        # NOTE(review): assumes the reviewer returns one result per input file,
        # in input order — confirm against the service contract.
        for i, review in enumerate(raw_reviews):
            vulns = review.get("vulnerabilities", [])
            total_vulns += len(vulns)

            # Default maintainability of 8.0 when the reviewer omits metrics.
            m_score = review.get("metrics", {}).get("maintainability", 8.0)
            maint_scores.append(m_score)

            content = request.files[i].content
            if content:
                matches = api_regex.findall(content)
                for match in matches:
                    found_apis.add(f"/{match}")

        num_files = len(maint_scores)
        avg_maint = (sum(maint_scores) / num_files) if num_files > 0 else 0
        health_score = calculate_repo_health(total_vulns, avg_maint)

        return {
            "repo_health": health_score,
            "health_label": "Excellent" if health_score > 85 else "Good" if health_score > 60 else "Critical",
            "security_issues": total_vulns,
            "performance_ratio": f"{int(avg_maint * 10)}%",
            "exposed_apis": list(found_apis),
            "total_files_processed": num_files,
            "average_maintainability": round(avg_maint, 1)
        }
    except Exception as e:
        if "429" in str(e):
            raise HTTPException(status_code=429, detail="Quota exceeded")
        logger.error(f"Dashboard stats failed: {e}")
        raise HTTPException(status_code=500, detail="Failed to sync dashboard metrics")


@app.post("/analyze-file")
async def analyze_file(request: FileRequest):
    """Classify one file and generate its natural-language summary and tags."""
    try:
        result = classifier.predict(request.fileName, request.content)
        summary = classifier.generate_file_summary(request.content, request.fileName)
        tags = classifier.extract_tags(request.content, request.fileName)
        return {
            "fileName": request.fileName,
            "layer": result["label"],
            "summary": summary,
            "tags": tags,
            "embedding": result["embedding"],
        }
    except Exception as e:
        if "429" in str(e):
            raise HTTPException(status_code=429, detail="Limit Reached")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/semantic-search")
async def semantic_search(request: SearchRequest):
    """Search over client-supplied embeddings, falling back to the repo cache.

    Returns an empty result set when no embeddings are available from either source.
    """
    try:
        embeddings = request.embeddings
        if not embeddings and request.repoName and request.repoName in REPO_CACHE:
            embeddings = REPO_CACHE[request.repoName]
        if not embeddings:
            return {"results": []}
        results = classifier.semantic_search(request.query, embeddings)
        return {"results": results}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/chat")
async def chat(request: ChatRequest):
    """Answer a question about the repo, grounding the LLM in supplied file contents.

    On provider quota exhaustion returns a structured quota_error payload
    (HTTP 200) so the Compose client can render it; other failures raise 500.
    """
    # FIX: removed an unused `start_time = time.time()` local — it was never read.
    try:
        context_str = ""
        for item in request.context:
            context_str += f"--- FILE: {item['fileName']} ---\n{item['content']}\n\n"
        prompt = f"""
You are "GitGud AI", an expert software architect.
Repository: "{request.repoName}"

CONTEXT:
{context_str if request.context else "(NO CODE PROVIDED)"}

USER QUESTION: {request.query}
"""
        response = llm_engine.generate_text(prompt)
        return {"response": response, "status": "success"}
    except Exception as e:
        if "429" in str(e):
            # Return a structured error so Compose can show a nice UI
            return {"response": "⚠️ Daily limit reached. Try again in a bit!", "status": "quota_error"}
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/generate-guide")
async def generate_guide(request: GuideRequest):
    """Generate an onboarding guide plus a structured file tree for the Compose UI.

    Parses the messy tree text out of the generated markdown into JSON nodes.
    """
    try:
        markdown = guide_generator.generate_markdown(request.repoName, request.filePaths)

        # Extract the messy tree part (the section after "Project Structure",
        # up to the next blank-line break) and normalize it into JSON nodes.
        tree_match = re.search(r"Project Structure\n\n(.*?)(?=\n\n|$)", markdown, re.S)
        structured_tree = []
        if tree_match:
            structured_tree = parse_tree_to_list(tree_match.group(1))

        return {
            "markdown": markdown,
            "structured_tree": structured_tree,
            "project_name": request.repoName
        }
    except Exception as e:
        if "429" in str(e):
            raise HTTPException(status_code=429, detail="AI Quota Exceeded")
        raise HTTPException(status_code=500, detail=str(e))


# 6. Application Entry Point
if __name__ == "__main__":
    # PORT is provided by the hosting platform (e.g. HF Spaces); default 7860.
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)