# Commit metadata captured from the repository page (not Python code):
#   Author: Sonu Prasad
#   Message: Optimize backend for production
#   Commit: 4c6e0cc
"""
GitHub Companion API - Main FastAPI Application
A high-performance API for analyzing and chatting with GitHub repositories.
Optimized for Hugging Face Spaces deployment with multi-user support.
"""
import os
import uuid
import shutil
import pathlib
import tempfile
import asyncio
from typing import List, Literal
from contextlib import asynccontextmanager
from fastapi import FastAPI, BackgroundTasks, HTTPException, Query
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from git import Repo
from langchain_core.messages import AIMessage, HumanMessage
from ai_core import create_conversational_chain, query_with_context, embed_entire_repository
from shared import analysis_jobs, get_session, set_session, update_session
# ============================================================================
# Configuration
# ============================================================================
# Root directory under the OS temp dir; each analysis session gets its own
# subdirectory (a shallow clone of the requested repository) keyed by session id.
SESSIONS_BASE_DIR = pathlib.Path(tempfile.gettempdir()) / "repo_sessions"
# ============================================================================
# Lifespan Context Manager (Startup/Shutdown)
# ============================================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: prepare directories at startup, log at shutdown."""
    # Make sure the per-session workspace root exists before any request arrives.
    SESSIONS_BASE_DIR.mkdir(exist_ok=True)
    print(f"✅ GitHub Companion API started. Sessions dir: {SESSIONS_BASE_DIR}")
    yield
    # No teardown needed yet; this is the hook point for future cleanup work.
    print("🛑 GitHub Companion API shutting down.")
# ============================================================================
# FastAPI App Initialization
# ============================================================================
# Application instance; the lifespan handler above manages startup/shutdown.
app = FastAPI(
    title="GitHub Companion API",
    description="API for high-performance analysis and contextual chat with GitHub repositories.",
    version="5.0.0",
    lifespan=lifespan
)
# CORS Configuration (allows all origins for Hugging Face Spaces)
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers for credentialed requests — confirm the frontend does
# not rely on cookies/auth headers cross-origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ============================================================================
# Pydantic Models
# ============================================================================
# Request body for POST /analyze.
class RepoRequest(BaseModel):
    # Public URL of the Git repository to clone and analyze.
    repo_url: str
# Response for POST /analyze: the id used by all subsequent endpoints.
class AnalysisResponse(BaseModel):
    session_id: str
# Response for GET /status/{session_id}.
class StatusResponse(BaseModel):
    session_id: str
    # Lifecycle states set by initial_analysis_task / get_analysis_status.
    status: Literal["pending", "cloning", "summarizing", "embedding_background", "completed", "failed"]
    # Populated with the exception text when status is "failed".
    message: str | None = None
# One entry in the repository's file listing.
class FileDetail(BaseModel):
    # Path relative to the repo root, POSIX-style separators.
    path: str
    size_bytes: int
# Response for GET /result/{session_id}.
class AnalysisResult(BaseModel):
    repo_url: str
    # All non-ignored files, sorted by path.
    directory_structure: List[FileDetail]
    # One-paragraph AI-generated project summary.
    initial_summary: str
# Response for GET /file-content/{session_id}.
class FileContentResponse(BaseModel):
    path: str
    content: str
# Request body for POST /chat/{session_id}.
class ChatRequest(BaseModel):
    query: str
    # Repo-relative paths the user pinned for extra context (optional).
    pinned_files: List[str] = []
# Response for POST /chat/{session_id}.
class ChatResponse(BaseModel):
    answer: str
# A single client-side file edit to apply before zipping.
class ModifiedFile(BaseModel):
    path: str
    content: str
# Request body for POST /download-zip/{session_id}.
class DownloadRequest(BaseModel):
    modified_files: List[ModifiedFile]
# ============================================================================
# Utility Functions
# ============================================================================
def is_text_file(file_path: str) -> bool:
    """Heuristically decide whether *file_path* contains text.

    The previous implementation opened the file with errors='ignore', which
    silently accepts arbitrary binary data — so nearly every readable file
    (including images and archives) was classified as text and queued for
    embedding. This version uses git's heuristic instead: read the first
    512 bytes raw and treat the file as binary if a NUL byte appears.

    Returns:
        True if the sampled bytes look like text; False for binary content
        or for paths that cannot be read (missing file, permissions, etc.).
    """
    try:
        with open(file_path, 'rb') as f:
            sample = f.read(512)
    except OSError:
        return False
    # NUL bytes essentially never occur in real-world text encodings.
    return b'\x00' not in sample
def initial_analysis_task(session_id: str, repo_url: str, background_tasks: BackgroundTasks):
    """
    Background task: clone *repo_url*, index its files, build the initial RAG
    chain from key manifest files, generate a one-paragraph summary, and then
    schedule full-repository embedding as a follow-up task.

    Runs in FastAPI's thread pool (it is a sync function added via
    BackgroundTasks), so blocking git/disk work here does not stall the event
    loop. Progress is published through update_session(session_id, "status", ...)
    and consumed by GET /status/{session_id}. On any failure the status is set
    to "failed" with the exception text in "message".

    NOTE(review): this relies on Starlette executing tasks appended to an
    already-running BackgroundTasks list (the embed task is added below while
    this task is itself being iterated) — works with current Starlette, but
    fragile; confirm on upgrades.
    """
    # Per-session working directory for the shallow clone.
    session_repo_path = SESSIONS_BASE_DIR / session_id
    try:
        # Cleanup if exists (e.g. a stale directory from a crashed prior run).
        if session_repo_path.exists():
            shutil.rmtree(session_repo_path)
        SESSIONS_BASE_DIR.mkdir(exist_ok=True)
        update_session(session_id, "status", "cloning")
        # Clone repository (shallow clone for speed)
        Repo.clone_from(repo_url, str(session_repo_path), depth=1)
        update_session(session_id, "repo_path", str(session_repo_path))
        # Derive a display name from the URL's last segment, e.g. "owner/repo.git" -> "repo".
        repo_name = repo_url.split('/')[-1].replace('.git', '')
        update_session(session_id, "repo_name", repo_name)
        # Define ignore patterns (directories pruned during the walk, plus
        # individual file names skipped outright).
        ignore_patterns = {'.git', '.gitignore', '__pycache__', 'node_modules', 'dist', 'build', '.venv', 'venv'}
        all_file_details = []
        key_file_paths_for_summary = []
        all_text_file_paths_for_embedding = []
        # Manifest/readme files (matched case-insensitively) that best describe
        # the project; only these feed the initial summary chain.
        summary_candidate_names = {"readme.md", "package.json", "pyproject.toml", "requirements.txt", "pom.xml", "build.gradle", "cargo.toml"}
        # Walk the repository; topdown=True lets us prune ignored dirs in place.
        for root, dirs, files in os.walk(str(session_repo_path), topdown=True):
            dirs[:] = [d for d in dirs if d not in ignore_patterns]
            for name in files:
                if name in ignore_patterns:
                    continue
                file_path = os.path.join(root, name)
                # Skip symlinks: they may point outside the clone.
                if not os.path.islink(file_path):
                    try:
                        # Prefer pathlib for POSIX-style relative paths.
                        relative_path = pathlib.Path(file_path).relative_to(session_repo_path).as_posix()
                    except ValueError:
                        # Fallback (e.g. mixed drive/casing on Windows): os.path
                        # with backslashes normalized to forward slashes.
                        relative_path = os.path.relpath(file_path, str(session_repo_path)).replace("\\", "/")
                    file_size = os.path.getsize(file_path)
                    all_file_details.append(FileDetail(path=relative_path, size_bytes=file_size))
                    if is_text_file(file_path):
                        all_text_file_paths_for_embedding.append(file_path)
                        if name.lower() in summary_candidate_names:
                            key_file_paths_for_summary.append(file_path)
        update_session(session_id, "status", "summarizing")
        # Initialize chat history (list of LangChain messages, shared with /chat).
        update_session(session_id, "chat_history", [])
        # Create RAG chain with key files only — fast first answer; the full
        # repo is embedded later in the background.
        rag_chain = create_conversational_chain(key_file_paths_for_summary, session_id)
        if not rag_chain:
            raise Exception("Failed to create initial AI chain.")
        update_session(session_id, "rag_chain", rag_chain)
        # Generate initial summary
        job = get_session(session_id)
        chat_history = job.get("chat_history", [])
        summary_query = "Based on the provided files (like README, package.json, etc.), what is the primary purpose of this software project? Provide a concise, one-paragraph summary."
        initial_summary = query_with_context(rag_chain, chat_history, summary_query, [], str(session_repo_path))
        # Store result (must happen before the status flips, since /result
        # serves sessions already in "embedding_background").
        result = AnalysisResult(
            repo_url=repo_url,
            directory_structure=sorted(all_file_details, key=lambda x: x.path),
            initial_summary=initial_summary
        )
        update_session(session_id, "result", result)
        update_session(session_id, "status", "embedding_background")
        # Start background embedding of every text file (see NOTE above about
        # appending to an in-flight BackgroundTasks list).
        background_tasks.add_task(embed_entire_repository, session_id, all_text_file_paths_for_embedding)
    except Exception as e:
        # Any failure (clone, walk, chain creation, LLM call) marks the
        # session failed; the message is surfaced via /status.
        update_session(session_id, "status", "failed")
        update_session(session_id, "message", str(e))
        print(f"❌ Analysis failed for session {session_id}: {e}")
# ============================================================================
# API Endpoints
# ============================================================================
@app.get("/")
def read_root():
    """Landing endpoint advertising service name, version, and docs location."""
    payload = {
        "message": "GitHub Companion Backend is Running",
        "version": "5.0.0",
        "docs": "/docs",
    }
    return JSONResponse(content=payload)
@app.get("/health")
def health_check():
    """Liveness probe consumed by Hugging Face Spaces."""
    status_payload = {"status": "healthy"}
    return JSONResponse(content=status_payload)
@app.post("/analyze", response_model=AnalysisResponse, status_code=202)
def submit_analysis(request: RepoRequest, background_tasks: BackgroundTasks):
    """Accept a repo URL, schedule the analysis, and return the session id at once."""
    new_session_id = str(uuid.uuid4())
    # Register the session before scheduling work so /status can find it
    # even if the client polls immediately.
    set_session(new_session_id, {"status": "pending"})
    background_tasks.add_task(initial_analysis_task, new_session_id, request.repo_url, background_tasks)
    return AnalysisResponse(session_id=new_session_id)
@app.get("/status/{session_id}", response_model=StatusResponse)
def get_analysis_status(session_id: str):
    """Report the current lifecycle state of an analysis job."""
    job = get_session(session_id)
    if not job:
        raise HTTPException(status_code=404, detail="Session ID not found.")
    # Lazily promote to "completed" once the background embedding has flagged
    # itself done; the flag is written by the embedding task, not here.
    still_embedding = job.get("status") == "embedding_background"
    if still_embedding and job.get("embedding_complete"):
        update_session(session_id, "status", "completed")
        job["status"] = "completed"
    return StatusResponse(session_id=session_id, status=job["status"], message=job.get("message"))
@app.get("/result/{session_id}", response_model=AnalysisResult)
def get_analysis_result(session_id: str):
    """Return the stored analysis result once summarization has finished."""
    job = get_session(session_id)
    # The result exists as soon as embedding starts, so both states are servable.
    ready_states = ("embedding_background", "completed")
    if not job or job.get("status") not in ready_states:
        raise HTTPException(status_code=400, detail="Job not found or not ready.")
    return job["result"]
@app.get("/file-content/{session_id}", response_model=FileContentResponse)
def get_file_content(session_id: str, file_path: str = Query(..., alias="path")):
    """Return the UTF-8 text of one repository file, rejecting path traversal."""
    job = get_session(session_id)
    if not job or "repo_path" not in job:
        raise HTTPException(status_code=404, detail="Session not found.")
    repo_root = pathlib.Path(job["repo_path"]).resolve()
    target = (repo_root / file_path).resolve()
    # Security: the resolved target must stay inside the repo root
    # (blocks "../.." style traversal).
    if not target.is_relative_to(repo_root):
        raise HTTPException(status_code=403, detail="Access denied.")
    if not target.is_file():
        raise HTTPException(status_code=404, detail="File not found.")
    try:
        text = target.read_text(encoding="utf-8")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error reading file: {str(e)}")
    return FileContentResponse(path=file_path, content=text)
@app.post("/chat/{session_id}", response_model=ChatResponse)
def chat_with_repo(session_id: str, request: ChatRequest):
    """Answer a user question about the repository via the session's RAG chain.

    Fix: the original only verified "rag_chain" and then read job["repo_path"]
    unconditionally — a session missing its clone path surfaced as an
    unhandled KeyError (HTTP 500) instead of a clean 404.

    Raises:
        HTTPException 404: unknown session, or its chain/clone is not ready.
    """
    job = get_session(session_id)
    if not job or "rag_chain" not in job or "repo_path" not in job:
        raise HTTPException(status_code=404, detail="Chat session not ready.")
    rag_chain = job["rag_chain"]
    # NOTE(review): this is the session's own history list; query_with_context
    # presumably appends the new turn in place — confirm in ai_core.
    chat_history = job.get("chat_history", [])
    answer = query_with_context(rag_chain, chat_history, request.query, request.pinned_files, job["repo_path"])
    return ChatResponse(answer=answer)
@app.post("/download-zip/{session_id}")
async def download_zip(session_id: str, request: DownloadRequest, background_tasks: BackgroundTasks):
    """Download the repository as a ZIP file with any modifications applied.

    Fix: shutil.make_archive is blocking disk/CPU work, and calling it
    directly inside this async handler stalled the event loop for the whole
    archive duration on large repos. It now runs in a worker thread via
    asyncio.to_thread (the module imported asyncio without using it).

    Raises:
        HTTPException 404: unknown session or repository not cloned yet.
    """
    job = get_session(session_id)
    if not job or "repo_path" not in job:
        raise HTTPException(status_code=404, detail="Session not found.")
    repo_base_path = pathlib.Path(job["repo_path"]).resolve()
    repo_name = job.get("repo_name", session_id)
    temp_zip_dir = pathlib.Path(tempfile.gettempdir()) / "temp_zips"
    # Apply client-side edits before zipping; silently skip any path that
    # would escape the repository root (path-traversal guard).
    for modified_file in request.modified_files:
        file_to_update = (repo_base_path / modified_file.path).resolve()
        if not file_to_update.is_relative_to(repo_base_path):
            continue
        file_to_update.parent.mkdir(parents=True, exist_ok=True)
        file_to_update.write_text(modified_file.content, encoding="utf-8")
    # Create ZIP (off the event loop — see docstring).
    temp_zip_dir.mkdir(exist_ok=True)
    zip_path_base = temp_zip_dir / f"{repo_name}-{session_id}"
    zip_path_final = await asyncio.to_thread(
        shutil.make_archive, str(zip_path_base), 'zip', str(repo_base_path)
    )
    # Delete the archive once the response has been streamed to the client.
    background_tasks.add_task(os.remove, zip_path_final)
    return FileResponse(
        path=zip_path_final,
        media_type='application/zip',
        filename=f'{repo_name}-modified.zip'
    )