import os
import re
import logging
import traceback
import time
import asyncio
from typing import List, Optional, Dict
from concurrent.futures import ThreadPoolExecutor
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn

# Load environment variables before importing services that read them
load_dotenv()

# Import internal services
from app.predictor import classifier, guide_generator, reviewer
from app.core.model_loader import llm_engine

# 1. Setup Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 2. Initialize FastAPI
app = FastAPI(title="GitGud AI Service")

# Global embedding cache
REPO_CACHE: Dict[str, Dict[str, List[float]]] = {}
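# REPO_CACHE maps repoName -> {fileName -> embedding vector}, e.g. (illustrative):
#   REPO_CACHE["demo-repo"]["src/auth.py"] = [0.12, -0.53, ...]
# It lives in process memory only, so it is lost on restart and not shared
# across workers — assumed acceptable for a single-process Space.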
# 3. Data Models
class FileRequest(BaseModel):
    fileName: str
    content: Optional[str] = None
    repoName: Optional[str] = None

class BatchReviewRequest(BaseModel):
    files: List[FileRequest]

class GuideRequest(BaseModel):
    repoName: str
    filePaths: List[str]

class SearchRequest(BaseModel):
    query: str
    embeddings: Optional[Dict[str, List[float]]] = None
    repoName: Optional[str] = None

class ChatRequest(BaseModel):
    query: str
    context: List[Dict[str, str]]
    repoName: str
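# Example ChatRequest payload (illustrative values):
#   {
#     "query": "Where is auth handled?",
#     "context": [{"fileName": "auth.py", "content": "..."}],
#     "repoName": "demo-repo"
#   }
# Each context entry is expected to carry "fileName" and "content" keys,
# matching how the chat endpoint builds its prompt below.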
# 4. Core Logic Helpers
def calculate_repo_health(total_vulns: int, avg_maint: float) -> int:
    """Scores repo health from average maintainability (assumed 0-10 scale)
    minus a per-vulnerability penalty, clamped to the 10-100 range."""
    base_score = avg_maint * 10
    penalty = total_vulns * 8
    final_score = base_score - penalty
    return int(max(10, min(100, final_score)))
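# Worked example (illustrative): 3 vulnerabilities at average maintainability
# 7.5 gives 7.5 * 10 - 3 * 8 = 75 - 24 = 51, inside the 10-100 clamp.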
def parse_tree_to_list(raw_tree: str):
    """Parses a text tree into JSON objects for the Compose LazyColumn."""
    nodes = []
    for line in raw_tree.strip().split('\n'):
        # Rough depth heuristic based on tree characters and spacing.
        # '│' (box-drawing) is counted alongside ASCII '|' since generated
        # trees may use either — an assumption about the tree format.
        level = line.count('|') + line.count('│') + (line.count(' ') // 2)
        # Strip the tree-drawing characters to recover the bare name
        name = re.sub(r'[|│└├─]', '', line).strip()
        if name:
            nodes.append({
                "name": name,
                # Heuristic: anything containing a dot is treated as a file
                "type": "file" if '.' in name else "folder",
                "level": level
            })
    return nodes
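# Example (illustrative): the line "├── main.py" yields
#   {"name": "main.py", "type": "file", "level": 0}
# with level growing as indentation and vertical bars accumulate.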
# 5. Endpoints
# NOTE: the @app route decorators were missing from the source; the paths
# below are assumed placeholders — adjust them to match your clients.
@app.get("/health")
def health_check():
    return {
        "status": "online",
        "model": "microsoft/codebert-base",
        "device": getattr(classifier, "device", "cpu"),
        "cached_repos": list(REPO_CACHE.keys()),
    }

@app.get("/usage")
def get_usage():
    return llm_engine.get_usage_stats()

@app.post("/classify")
async def classify_file(request: FileRequest):
    try:
        result = classifier.predict(request.fileName, request.content)
        # Cache the embedding per repo so later searches can reuse it
        if request.repoName:
            REPO_CACHE.setdefault(request.repoName, {})[request.fileName] = result["embedding"]
        return {
            "fileName": request.fileName,
            "layer": result["label"],
            "confidence": result["confidence"],
            "embedding": result["embedding"]
        }
    except Exception as e:
        logger.error(f"Classify failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
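# Example response from the classify endpoint above (illustrative values;
# the actual "layer" labels depend on the classifier's label set):
#   {"fileName": "auth.py", "layer": "backend", "confidence": 0.93,
#    "embedding": [0.12, -0.53, ...]}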
@app.post("/review/batch")
async def review_batch_code(request: BatchReviewRequest):
    try:
        loop = asyncio.get_running_loop()

        # Helper that reviews one file at a time. The file is wrapped in a
        # list so the existing service method still accepts it, but each
        # thread only processes a single file.
        def process_single_file(file_req):
            return reviewer.service.review_batch_code([file_req])

        # Run up to 15 reviews concurrently on a thread pool
        with ThreadPoolExecutor(max_workers=15) as executor:
            tasks = [
                loop.run_in_executor(executor, process_single_file, f)
                for f in request.files
            ]
            # gather runs the tasks concurrently and collects results in order
            raw_reviews = await asyncio.gather(*tasks, return_exceptions=True)

        # Keep successful results; handle individual file failures gracefully
        valid_reviews = []
        for i, result in enumerate(raw_reviews):
            if isinstance(result, Exception):
                # A rate limit aborts the whole batch immediately
                if "429" in str(result):
                    raise HTTPException(status_code=429, detail="AI Quota Exceeded")
                # Otherwise, log the per-file error but don't fail the batch
                logger.error(f"Failed to analyze {request.files[i].fileName}: {result}")
            else:
                valid_reviews.append(result)
        return {"results": valid_reviews}
    except HTTPException:
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
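# Example request body for the batch review endpoint above (illustrative):
#   {"files": [
#     {"fileName": "auth.py", "content": "def login(): ..."},
#     {"fileName": "db.py", "content": "def connect(): ..."}
#   ]}
# Each file runs on its own thread, so one failing file only drops its own
# entry from "results" unless it hits the 429 quota path.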
@app.post("/dashboard/stats")
async def get_dashboard_stats(request: BatchReviewRequest):
    try:
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor() as executor:
            raw_reviews = await loop.run_in_executor(
                executor, reviewer.service.review_batch_code, request.files
            )
        total_vulns = 0
        maint_scores = []
        found_apis = set()
        # Matches route registrations like get('/users') or post("/login")
        api_regex = re.compile(r'(?:get|post|put|delete|patch)\([\'"]/(.*?)[\'"]', re.IGNORECASE)
        for i, review in enumerate(raw_reviews):
            vulns = review.get("vulnerabilities", [])
            total_vulns += len(vulns)
            m_score = review.get("metrics", {}).get("maintainability", 8.0)
            maint_scores.append(m_score)
            content = request.files[i].content
            if content:
                for match in api_regex.findall(content):
                    found_apis.add(f"/{match}")
        num_files = len(maint_scores)
        avg_maint = (sum(maint_scores) / num_files) if num_files > 0 else 0
        health_score = calculate_repo_health(total_vulns, avg_maint)
        return {
            "repo_health": health_score,
            "health_label": "Excellent" if health_score > 85 else "Good" if health_score > 60 else "Critical",
            "security_issues": total_vulns,
            "performance_ratio": f"{int(avg_maint * 10)}%",
            "exposed_apis": list(found_apis),
            "total_files_processed": num_files,
            "average_maintainability": round(avg_maint, 1)
        }
    except Exception as e:
        if "429" in str(e):
            raise HTTPException(status_code=429, detail="Quota exceeded")
        logger.error(f"Dashboard stats failed: {e}")
        raise HTTPException(status_code=500, detail="Failed to sync dashboard metrics")
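# Example response from the dashboard endpoint above (illustrative): the
# worked health example (score 51) falls below the >60 "Good" cutoff:
#   {"repo_health": 51, "health_label": "Critical", "security_issues": 3,
#    "performance_ratio": "75%", "exposed_apis": ["/users"], ...}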
@app.post("/analyze")
async def analyze_file(request: FileRequest):
    try:
        result = classifier.predict(request.fileName, request.content)
        summary = classifier.generate_file_summary(request.content, request.fileName)
        tags = classifier.extract_tags(request.content, request.fileName)
        return {
            "fileName": request.fileName,
            "layer": result["label"],
            "summary": summary,
            "tags": tags,
            "embedding": result["embedding"],
        }
    except Exception as e:
        if "429" in str(e):
            raise HTTPException(status_code=429, detail="Limit Reached")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/search")
async def semantic_search(request: SearchRequest):
    try:
        # Prefer embeddings sent with the request; fall back to the repo cache
        embeddings = request.embeddings
        if not embeddings and request.repoName and request.repoName in REPO_CACHE:
            embeddings = REPO_CACHE[request.repoName]
        if not embeddings:
            return {"results": []}
        results = classifier.semantic_search(request.query, embeddings)
        return {"results": results}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/chat")
async def chat(request: ChatRequest):
    start_time = time.time()
    try:
        # Concatenate the provided files into a single prompt context
        context_str = ""
        for item in request.context:
            context_str += f"--- FILE: {item['fileName']} ---\n{item['content']}\n\n"
        prompt = f"""
You are "GitGud AI", an expert software architect.
Repository: "{request.repoName}"
CONTEXT: {context_str if request.context else "(NO CODE PROVIDED)"}
USER QUESTION: {request.query}
"""
        response = llm_engine.generate_text(prompt)
        return {"response": response, "status": "success"}
    except Exception as e:
        if "429" in str(e):
            # Return a structured error so Compose can show a friendly UI
            return {"response": "⚠️ Daily limit reached. Try again in a bit!", "status": "quota_error"}
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/guide")
async def generate_guide(request: GuideRequest):
    """Generates the guide markdown and parses its raw tree text into
    structured JSON for the Compose UI."""
    try:
        markdown = guide_generator.generate_markdown(request.repoName, request.filePaths)
        # Extract the raw tree section and convert it to structured nodes
        tree_match = re.search(r"Project Structure\n\n(.*?)(?=\n\n|$)", markdown, re.S)
        structured_tree = []
        if tree_match:
            structured_tree = parse_tree_to_list(tree_match.group(1))
        return {
            "markdown": markdown,
            "structured_tree": structured_tree,
            "project_name": request.repoName
        }
    except Exception as e:
        if "429" in str(e):
            raise HTTPException(status_code=429, detail="AI Quota Exceeded")
        raise HTTPException(status_code=500, detail=str(e))
# 6. Application Entry Point
if __name__ == "__main__":
    # 7860 is the conventional default port for Hugging Face Spaces
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)