Spaces:
Running
Running
| from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect | |
| import json | |
| from typing import Literal, Optional | |
| import asyncio | |
| from services.prompts import ( | |
| PR_DESCRIPTION_SYSTEM, | |
| SELECT_QUIRKY_REMARK_SYSTEM, | |
| SYSTEM, | |
| RELEASE_NOTES_SYSTEM, | |
| quirky_remarks, | |
| ) | |
| from services.llm_service import ( | |
| get_random_quirky_remarks, | |
| run_concurrent_tasks, | |
| get_llm, | |
| ) | |
| from aicore.const import SPECIAL_TOKENS, STREAM_END_TOKEN | |
| router = APIRouter() | |
| # WebSocket connection storage | |
| active_connections = {} | |
| active_histories = {} | |
| TRIGGER_PROMPT = """ | |
| Consider the following history of actionables from Git and return me the summary with N = '{N}' bullet points: | |
| {ACTIONS} | |
| """ | |
| TRIGGER_RELEASE_PROMPT = """ | |
| Consider the following history of actionables from Git and the previous Release Notes (if available). | |
| Generate me the next Release Notes based on the new Git Actionables matching the format of the previous releases: | |
| {ACTIONS} | |
| """ | |
| TRIGGER_PULL_REQUEST_PROMPT = """ | |
| You will now receive a list of commit messages between two branches. | |
| Using the system instructions provided above, generate a clear, concise, and professional **Pull Request Description** summarizing all changes from branch `{SRC}` to be merged into `{TARGET}`. | |
| Commits: | |
| {COMMITS} | |
| Please follow these steps: | |
| 1. Read and analyze the commit messages. | |
| 2. Identify and group related changes under appropriate markdown headers (e.g., Features, Bug Fixes, Improvements, Documentation, Tests). | |
| 3. Write a short **summary paragraph** explaining the overall purpose of this pull request. | |
| 4. Format the final output as a complete markdown-formatted PR description, ready to paste into GitHub. | |
| Begin your response directly with the formatted PR description—no extra commentary or explanation. | |
| """ | |
| async def websocket_endpoint( | |
| websocket: WebSocket, | |
| session_id: Optional[str] = None, | |
| action_type: Literal["recap", "release", "pull_request"] = "recap" | |
| ): | |
| """ | |
| WebSocket endpoint for real-time LLM operations. | |
| Handles three action types: | |
| - recap: Generate commit summaries with quirky remarks | |
| - release: Generate release notes based on git history | |
| - pull_request: Generate PR descriptions from commit diffs | |
| Args: | |
| websocket: WebSocket connection instance | |
| session_id: Session identifier for LLM and fetcher management | |
| action_type: Type of operation to perform | |
| Raises: | |
| HTTPException: If action_type is invalid | |
| """ | |
| await websocket.accept() | |
| # Select appropriate system prompt based on action type | |
| if action_type == "recap": | |
| QUIRKY_SYSTEM = SELECT_QUIRKY_REMARK_SYSTEM.format( | |
| examples=json.dumps(get_random_quirky_remarks(quirky_remarks), indent=4) | |
| ) | |
| system = [SYSTEM, QUIRKY_SYSTEM] | |
| elif action_type == "release": | |
| system = RELEASE_NOTES_SYSTEM | |
| elif action_type == "pull_request": | |
| system = PR_DESCRIPTION_SYSTEM | |
| else: | |
| raise HTTPException(status_code=404, detail="Invalid action type") | |
| # Store the active WebSocket connection | |
| active_connections[session_id] = websocket | |
| # Initialize LLM session | |
| llm = get_llm(session_id) | |
| try: | |
| while True: | |
| # Receive message from client | |
| message = await websocket.receive_text() | |
| msg_json = json.loads(message) | |
| message_content = msg_json.get("actions") | |
| N = msg_json.get("n", 5) | |
| src_branch = msg_json.get("src") | |
| target_branch = msg_json.get("target") | |
| # Validate inputs | |
| assert int(N) <= 15, "N must be <= 15" | |
| assert message_content, "Message content is required" | |
| # Build history/prompt based on action type | |
| if action_type == "recap": | |
| history = [ | |
| TRIGGER_PROMPT.format( | |
| N=N, | |
| ACTIONS=message_content | |
| ) | |
| ] | |
| elif action_type == "release": | |
| history = [ | |
| TRIGGER_RELEASE_PROMPT.format(ACTIONS=message_content) | |
| ] | |
| elif action_type == "pull_request": | |
| history = [ | |
| TRIGGER_PULL_REQUEST_PROMPT.format( | |
| SRC=src_branch, | |
| TARGET=target_branch, | |
| COMMITS=message_content) | |
| ] | |
| # Stream LLM response back to client | |
| response = [] | |
| async for chunk in run_concurrent_tasks( | |
| llm, | |
| message=history, | |
| system_prompt=system | |
| ): | |
| if chunk == STREAM_END_TOKEN: | |
| await websocket.send_text(json.dumps({"chunk": chunk})) | |
| break | |
| elif chunk in SPECIAL_TOKENS: | |
| continue | |
| await websocket.send_text(json.dumps({"chunk": chunk})) | |
| response.append(chunk) | |
| # Store response in history for potential follow-up | |
| history.append("".join(response)) | |
| except WebSocketDisconnect: | |
| # Clean up connection on disconnect | |
| if session_id in active_connections: | |
| del active_connections[session_id] | |
| except AssertionError as e: | |
| # Handle validation errors | |
| if session_id in active_connections: | |
| await websocket.send_text(json.dumps({"error": f"Validation error: {str(e)}"})) | |
| del active_connections[session_id] | |
| except Exception as e: | |
| # Handle unexpected errors | |
| if session_id in active_connections: | |
| await websocket.send_text(json.dumps({"error": str(e)})) | |
| del active_connections[session_id] | |
| def close_websocket_connection(session_id: str): | |
| """ | |
| Clean up and close the active WebSocket connection associated with the given session_id. | |
| This function is called during session expiration to ensure proper cleanup | |
| of WebSocket resources. | |
| Args: | |
| session_id: The session identifier whose WebSocket connection should be closed | |
| """ | |
| websocket = active_connections.pop(session_id, None) | |
| if websocket: | |
| asyncio.create_task(websocket.close()) |