# GitRecap / server /websockets.py
# github-actions[bot]
# Deploy app/api to HF Space
# d449470
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
import json
from typing import Literal, Optional
import asyncio
from services.prompts import (
PR_DESCRIPTION_SYSTEM,
SELECT_QUIRKY_REMARK_SYSTEM,
SYSTEM,
RELEASE_NOTES_SYSTEM,
quirky_remarks,
)
from services.llm_service import (
get_random_quirky_remarks,
run_concurrent_tasks,
get_llm,
)
from aicore.const import SPECIAL_TOKENS, STREAM_END_TOKEN
# Router on which this module's WebSocket route is registered.
router = APIRouter()
# WebSocket connection storage
# Maps session_id -> the WebSocket registered by websocket_endpoint. Entries are
# removed by the endpoint on disconnect/error and by close_websocket_connection.
active_connections = {}
# NOTE(review): active_histories is never read or written in the code visible
# here - presumably reserved for per-session history; confirm before removing.
active_histories = {}
# User prompt for the "recap" action.
# Placeholders: {N} = requested number of bullet points,
#               {ACTIONS} = the Git actionables text sent by the client.
TRIGGER_PROMPT = """
Consider the following history of actionables from Git and return me the summary with N = '{N}' bullet points:
{ACTIONS}
"""
# User prompt for the "release" action.
# Placeholder: {ACTIONS} = the text sent by the client; per the prompt wording
# it carries the Git actionables and, if available, the previous release notes.
TRIGGER_RELEASE_PROMPT = """
Consider the following history of actionables from Git and the previous Release Notes (if available).
Generate me the next Release Notes based on the new Git Actionables matching the format of the previous releases:
{ACTIONS}
"""
# User prompt for the "pull_request" action.
# Placeholders: {SRC} = source branch name, {TARGET} = target branch name,
#               {COMMITS} = the commit messages text sent by the client.
TRIGGER_PULL_REQUEST_PROMPT = """
You will now receive a list of commit messages between two branches.
Using the system instructions provided above, generate a clear, concise, and professional **Pull Request Description** summarizing all changes from branch `{SRC}` to be merged into `{TARGET}`.
Commits:
{COMMITS}
Please follow these steps:
1. Read and analyze the commit messages.
2. Identify and group related changes under appropriate markdown headers (e.g., Features, Bug Fixes, Improvements, Documentation, Tests).
3. Write a short **summary paragraph** explaining the overall purpose of this pull request.
4. Format the final output as a complete markdown-formatted PR description, ready to paste into GitHub.
Begin your response directly with the formatted PR description—no extra commentary or explanation.
"""
class _ClientInputError(Exception):
    """Raised when a client message fails the endpoint's input validation."""


def _build_history(action_type: str, msg_json: dict) -> list:
    """
    Validate one decoded client message and build the prompt history for it.

    Args:
        action_type: One of "recap", "release" or "pull_request"; the caller
            has already rejected any other value.
        msg_json: Decoded client payload. Reads the keys "actions" (required),
            "n" (defaults to 5), "src" and "target".

    Returns:
        A single-element list holding the formatted trigger prompt.

    Raises:
        _ClientInputError: If "n" is greater than 15 or "actions" is missing
            or empty.
        TypeError, ValueError: Propagated from int() when "n" is not numeric.
    """
    message_content = msg_json.get("actions")
    n_points = msg_json.get("n", 5)

    # Explicit checks rather than `assert`: assert statements are removed when
    # Python runs with -O, which would silently disable this validation.
    if int(n_points) > 15:
        raise _ClientInputError("N must be <= 15")
    if not message_content:
        raise _ClientInputError("Message content is required")

    if action_type == "recap":
        return [TRIGGER_PROMPT.format(N=n_points, ACTIONS=message_content)]
    if action_type == "release":
        return [TRIGGER_RELEASE_PROMPT.format(ACTIONS=message_content)]
    # Remaining case: "pull_request".
    return [
        TRIGGER_PULL_REQUEST_PROMPT.format(
            SRC=msg_json.get("src"),
            TARGET=msg_json.get("target"),
            COMMITS=message_content,
        )
    ]


async def _send_error(websocket: WebSocket, message: str) -> None:
    """Send an {"error": message} payload, ignoring a socket that is already gone."""
    try:
        await websocket.send_text(json.dumps({"error": message}))
    except (WebSocketDisconnect, RuntimeError):
        # Best effort only: the peer disconnected or the socket was already
        # closed, so there is nobody left to notify.
        pass


@router.websocket("/ws/{session_id}/{action_type}")
async def websocket_endpoint(
    websocket: WebSocket,
    session_id: Optional[str] = None,
    action_type: Literal["recap", "release", "pull_request"] = "recap"
):
    """
    WebSocket endpoint for real-time LLM operations.

    Handles three action types:
    - recap: Generate commit summaries with quirky remarks
    - release: Generate release notes based on git history
    - pull_request: Generate PR descriptions from commit messages

    Each text frame received must be a JSON object with an "actions" entry and,
    optionally, "n" (recap bullet count, default 5, at most 15), "src" and
    "target" (branch names for pull_request). The LLM output is streamed back
    as {"chunk": ...} frames; failures are reported as a single
    {"error": ...} frame, after which the handler stops serving the socket.

    Args:
        websocket: WebSocket connection instance
        session_id: Session identifier used to register the connection and to
            look up the LLM via get_llm
        action_type: Type of operation to perform

    Raises:
        HTTPException: If action_type is invalid
    """
    await websocket.accept()

    # Select appropriate system prompt based on action type
    if action_type == "recap":
        quirky_system = SELECT_QUIRKY_REMARK_SYSTEM.format(
            examples=json.dumps(get_random_quirky_remarks(quirky_remarks), indent=4)
        )
        system = [SYSTEM, quirky_system]
    elif action_type == "release":
        system = RELEASE_NOTES_SYSTEM
    elif action_type == "pull_request":
        system = PR_DESCRIPTION_SYSTEM
    else:
        raise HTTPException(status_code=404, detail="Invalid action type")

    # Store the active WebSocket connection
    active_connections[session_id] = websocket
    try:
        # Initialize LLM session. Errors raised here still propagate to the
        # caller, but the `finally` below now unregisters the connection.
        llm = get_llm(session_id)
        try:
            while True:
                # Receive message from client and turn it into a prompt
                message = await websocket.receive_text()
                history = _build_history(action_type, json.loads(message))

                # Stream LLM response back to client
                response = []
                async for chunk in run_concurrent_tasks(
                    llm,
                    message=history,
                    system_prompt=system
                ):
                    if chunk == STREAM_END_TOKEN:
                        # Forward the end marker so the client knows the
                        # stream is complete, then stop reading chunks.
                        await websocket.send_text(json.dumps({"chunk": chunk}))
                        break
                    elif chunk in SPECIAL_TOKENS:
                        # Other control tokens are not forwarded to the client.
                        continue
                    await websocket.send_text(json.dumps({"chunk": chunk}))
                    response.append(chunk)

                # Store response in history for potential follow-up.
                # NOTE(review): history is rebuilt for every incoming message,
                # so this reply is not reused by the next turn in this code.
                history.append("".join(response))
        except WebSocketDisconnect:
            # Client went away; cleanup happens in `finally`.
            pass
        except _ClientInputError as exc:
            # Handle validation errors
            if session_id in active_connections:
                await _send_error(websocket, f"Validation error: {exc}")
        except Exception as exc:
            # Handle unexpected errors (top-level boundary for this connection)
            if session_id in active_connections:
                await _send_error(websocket, str(exc))
    finally:
        # Unregister on every exit path (including cancellation), but only if
        # the entry still refers to this socket: a newer connection may have
        # re-registered under the same session_id in the meantime.
        if active_connections.get(session_id) is websocket:
            del active_connections[session_id]
# Strong references to in-flight close() tasks. The event loop only keeps weak
# references to tasks, so a task nobody else references could be
# garbage-collected before it finishes.
_pending_close_tasks = set()


def close_websocket_connection(session_id: str):
    """
    Clean up and close the active WebSocket connection associated with the given session_id.

    The connection is removed from active_connections immediately; the actual
    close is scheduled as a task on the running event loop and is not awaited
    here. Does nothing when no connection is registered for session_id.

    This function is called during session expiration to ensure proper cleanup
    of WebSocket resources.

    Args:
        session_id: The session identifier whose WebSocket connection should be closed

    Raises:
        RuntimeError: From asyncio.create_task when a connection is registered
            but there is no running event loop in the current thread.
    """
    websocket = active_connections.pop(session_id, None)
    # Compare against None explicitly instead of relying on the truthiness of
    # the connection object.
    if websocket is None:
        return

    close_task = asyncio.create_task(websocket.close())
    # Hold the task until it completes, then drop the reference.
    _pending_close_tasks.add(close_task)
    close_task.add_done_callback(_pending_close_tasks.discard)