# app.py — GitHub PR Reviewer Bot webhook listener (commit 2fdd3bc: use refactored files)
import os
import hmac
import hashlib
from typing import Optional
from dotenv import load_dotenv
from fastapi import FastAPI, Request, HTTPException, Header, status
from pydantic import BaseModel, Field
from src.langgraph_logic.state import PRReviewState
from src.langgraph_logic.graph import create_graph
from langgraph.graph import StateGraph, END
import sqlite3 # Import sqlite3 for direct connection if needed for older versions
from langgraph.checkpoint.sqlite import SqliteSaver
import os
import logging
import uuid
from github import Github # pip install PyGithub
from urllib.parse import quote_plus
# Load environment variables from .env file
load_dotenv()

app = FastAPI(
    title="GitHub PR Reviewer Bot Webhook Listener",
    description="Listens for GitHub PR events and initiates Langgraph workflows."
)

# --- Configuration ---
# Shared secret used to authenticate incoming GitHub webhook deliveries.
# (Never hard-code the secret here, even commented out — it ends up in VCS.)
GITHUB_WEBHOOK_SECRET = os.getenv("GITHUB_WEBHOOK_SECRET")
if not GITHUB_WEBHOOK_SECRET:
    raise ValueError("GITHUB_WEBHOOK_SECRET environment variable not set. Please create a .env file.")

# BUG FIX: the raw env value is a string, and any non-empty string (including
# "false") is truthy in Python; downstream the value is used as a boolean in
# the graph state. Parse the usual truthy spellings into a real bool.
_raw_require_approval = os.getenv("require_human_approval")
if not _raw_require_approval:
    print("require_human_approval key not found in environment variables.")
require_human_approval_from_env = (_raw_require_approval or "").strip().lower() in ("1", "true", "yes", "on")
print(f"Using require_human_approval key: {require_human_approval_from_env}")
# --- Build the Graph ---
graph_builder = create_graph()

import tempfile

# Persist LangGraph checkpoints in the temp directory — the only reliably
# writable path in containerized deployments (e.g. HF Spaces).
SQLITE_DB_PATH = os.path.join(tempfile.gettempdir(), "langgraph_checkpoints.sqlite")

# --- Checkpointer and Graph Compilation ---
global_memory_saver = None  # Stays None if the SQLite connection fails.
try:
    # check_same_thread=False: the web server may touch the connection from
    # multiple worker threads.
    conn = sqlite3.connect(SQLITE_DB_PATH, check_same_thread=False)
    global_memory_saver = SqliteSaver(conn=conn)
    print(f"SqliteSaver initialized successfully: {type(global_memory_saver)}")
except Exception as e:
    print(f"Error initializing SqliteSaver connection: {e}")
    # If an error occurs here, global_memory_saver will remain None

if global_memory_saver:
    # Interrupt before the human-input node so a reviewer can approve/reject
    # via the Gradio UI, then resume through /resume-review.
    global_graph = graph_builder.compile(
        checkpointer=global_memory_saver,
        interrupt_before=["update_review_body_based_on_human_input_node"]
    )
    print("Graph compiled successfully.")
else:
    print("SqliteSaver not initialized, graph compilation skipped.")
    global_graph = None  # Endpoints must check for None before streaming.
# --- Pydantic Models for Payload Parsing ---
class Repository(BaseModel):
    """Subset of the GitHub repository object delivered in webhook payloads."""
    id: int
    full_name: str  # "owner/repo" form, used to look the repo up via PyGithub
    html_url: str
    clone_url: str
class PullRequestUser(BaseModel):
    """GitHub account (user or bot) referenced by a PR or webhook sender."""
    login: str
    id: int
    type: str  # e.g. "User" — TODO confirm other values against GitHub docs
class PullRequest(BaseModel):
    """Subset of the GitHub pull_request object this service consumes."""
    id: int
    url: str
    html_url: str
    diff_url: str  # This is the changelist URL (diff URL); PR number is parsed from it
    state: str
    title: str
    user: PullRequestUser
    base: dict  # Contains info about the base branch/repo
    head: dict  # Contains info about the head branch/repo
class GitHubWebhookPayload(BaseModel):
    """Top-level GitHub 'pull_request' webhook payload."""
    action: str  # only "opened" is processed downstream
    pull_request: PullRequest
    repository: Repository
    sender: PullRequestUser
# --- Helper Function for Signature Verification ---
def verify_signature(x_hub_signature_256: str, payload_body: bytes, secret: Optional[str] = None) -> bool:
    """
    Verify a GitHub webhook HMAC-SHA256 signature.

    Args:
        x_hub_signature_256: Value of the 'X-Hub-Signature-256' header,
            expected in the form 'sha256=<hex digest>'.
        payload_body: Raw request body bytes the signature was computed over.
        secret: Webhook secret to verify against; defaults to the module-level
            GITHUB_WEBHOOK_SECRET when omitted (backward compatible).

    Returns:
        True if the signature is well-formed and matches, False otherwise.
    """
    if secret is None:
        secret = GITHUB_WEBHOOK_SECRET
    if not secret:
        return False  # Should be caught by the startup check, but for safety
    # GitHub sends the header as 'sha256=<signature>'. BUG FIX: the original
    # split('sha256=')[1] raised IndexError on a malformed/missing prefix —
    # treat that as a verification failure instead of a 500.
    if not x_hub_signature_256 or not x_hub_signature_256.startswith('sha256='):
        return False
    expected_signature = x_hub_signature_256.split('sha256=', 1)[1]
    # Calculate the HMAC SHA256 signature over the raw body.
    mac = hmac.new(
        secret.encode('utf-8'),
        msg=payload_body,
        digestmod=hashlib.sha256
    )
    calculated_signature = mac.hexdigest()
    # Use hmac.compare_digest for constant-time comparison to prevent timing attacks
    return hmac.compare_digest(calculated_signature, expected_signature)
# Base URL of the Gradio UI Space where a human approves/rejects a paused review.
GRADIO_SPACE_BASE_URL = "https://nikhmr1235-pr-reviewer-gradio-ui.hf.space" # Your Gradio Space URL
# --- Webhook Endpoint ---
@app.post("/webhook")
async def github_webhook(
    request: Request,
    x_github_event: str = Header(..., alias="X-GitHub-Event"),
    x_hub_signature_256: str = Header(..., alias="X-Hub-Signature-256"),
):
    """
    Receive GitHub webhook events and start the PR-review LangGraph workflow.

    Only 'pull_request' events with action 'opened' are processed. The request
    signature is verified against the configured webhook secret before parsing.
    If human approval is required, the graph pauses at its interrupt node and a
    comment linking to the Gradio review UI is posted on the PR.

    Raises:
        HTTPException 401: signature verification failed.
        HTTPException 400: body is not a valid webhook payload.
    """
    print("/webhook triggered successfully")
    # GitHub sends a 'ping' when the webhook is first registered.
    if x_github_event == "ping":
        print("Received GitHub 'ping' event. Webhook is active.")
        return {"message": "pong"}
    if x_github_event != "pull_request":
        print(f"Received non-pull_request event: {x_github_event}. Skipping.")
        return {"message": f"Event type '{x_github_event}' ignored."}
    payload_body = await request.body()
    # 1. Signature verification (constant-time HMAC over the raw body).
    if not verify_signature(x_hub_signature_256, payload_body):
        print("Webhook signature verification failed!")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid signature"
        )
    print("Webhook signature verified successfully.")
    # 2. Payload parsing into the typed model.
    try:
        payload = GitHubWebhookPayload.model_validate_json(payload_body)
        print("Payload parsed successfully.")
    except Exception as e:
        print(f"Error parsing JSON payload: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid JSON payload: {e}"
        )
    # 3. Filter: only newly-opened PRs start a review.
    if payload.action != "opened":
        print(f"Received PR event with action '{payload.action}'. Only 'opened' events are processed. Skipping.")
        return {"message": f"PR event action '{payload.action}' ignored."}
    print(f"New PR creation event detected! Action: {payload.action}")
    # Extract the fields the graph needs.
    repo_id = payload.repository.id
    pr_repo_name = payload.repository.full_name
    changelist_url = payload.pull_request.diff_url
    # The diff URL looks like .../pull/<number>.diff — recover the PR number.
    pr_id = int(changelist_url.split('/pull/')[1].split('.')[0])
    repository_clone_url = payload.repository.clone_url
    pr_action = payload.action  # This will now always be "opened"
    pr_title = payload.pull_request.title
    print(f"Received PR event: Action={pr_action}, PR ID={pr_id}, Repo ID={repo_id}")
    print(f"Changelist URL (Diff URL): {changelist_url}")
    print(f"Repository Clone URL: {repository_clone_url}")
    print(f"PR Title: {pr_title}")
    print(f"pr_repo_name: {pr_repo_name} ")
    # Each invocation gets a unique thread_id so the paused graph can be
    # resumed later from /resume-review with the same id.
    graph_thread_id = str(uuid.uuid4())
    print(f"\n--- Starting New Graph Execution with thread_id: {graph_thread_id} ---")
    if not global_graph:
        print("Error: Graph not initialized. Check server logs.")
        return {"message": "Graph not initialized; review skipped."}
    print("global_graph is initialised...")
    initial_state_data = {
        "pr_id": pr_id,
        "repo_name": pr_repo_name,
        "pr_title": None,  # Optional; populated by the graph itself
        "code_diff": None,
        "file_contents": None,
        "llm_markdown_review": None,
        "parsed_llm_review_data": None,
        "main_comment_body": None,
        "review_status": "initiated",  # REQUIRED: initial valid Literal value
        "require_human_approval": require_human_approval_from_env,  # REQUIRED
        "human_approval_status": False,
        "human_feedback_message": None,
        "original_review_id": None,
        "original_review_url": None,
        "final_review_id": None,
        "final_review_url": None,
        "last_error": None,
    }
    initial_state = PRReviewState(**initial_state_data)
    print("Graph is being streamed....")
    try:
        # Stream until the graph either finishes or hits the configured
        # interrupt_before node (human-approval pause).
        for s in global_graph.stream(initial_state, {"configurable": {"thread_id": graph_thread_id}}):
            if "__end__" in s:
                break
            elif "__interrupt__" in s:
                print(f"Graph interrupted BEFORE {s.get('__interrupt__', 'Unknown')} node.")
                break
        print("printing state from web_hook since graph is paused at this point")
        current_state_snapshot = global_graph.get_state({"configurable": {"thread_id": graph_thread_id}})
        current_state = current_state_snapshot.values
        require_human_approval = current_state['require_human_approval']
        repo_name = current_state['repo_name']
        pr_id = current_state['pr_id']
        review_id = current_state['original_review_id']
        review_url = current_state['original_review_url']
        print(f"retrieving review ID and URL from paused graph state")
        print(f"original_review_id : {review_id}\t\t\t original_review_url:{review_url} ")
        if not require_human_approval:
            print(f"Graph excution has completed with automated review:{review_url}\n\n Note: No human approval is requested as per the state_require_human_approval:{require_human_approval}")
            return {"message": f"Automated review completed: {review_url}"}
        thread_id = graph_thread_id
        # Link the reviewer to the Gradio UI, carrying the ids needed to resume.
        gradio_review_url = f"{GRADIO_SPACE_BASE_URL}?review_id={review_id}&thread_id={thread_id}"
        print(f"gradio_review_url:{gradio_review_url}\nthread_id:{thread_id}")
        comment_body = f"""
🤖 **Human Review Required!** 🤖
The automated review for this Pull Request has paused, awaiting your decision.
Please visit the following link to **Approve or Reject** and provide feedback:
[Click here to review PR #{pr_id}]({gradio_review_url})
Once you submit your decision, the automated workflow will resume. Thank you!
"""
        try:
            # BUG FIX: the original referenced an undefined name `git_hub_token`
            # (NameError at runtime); read the API token from the environment.
            github_token = os.getenv("GITHUB_TOKEN")
            g = Github(github_token)
            repo = g.get_repo(repo_name)
            # In the GitHub API, PR comments are "issue comments" because PRs
            # are a type of issue internally.
            issue = repo.get_issue(pr_id)
            issue.create_comment(comment_body)
            print(f"Posted comment on PR #{pr_id} in {pr_repo_name}.")
        except Exception as e:
            print(f"Error posting comment to GitHub PR #{pr_id}: {e}")
            # Best-effort: the graph is still paused and can be resumed manually.
    except Exception as e:
        return (f"An error occurred during graph start: {e}")
    return {"message": "PR review workflow started; awaiting human approval."}
@app.get("/")
async def read_root():
    """Health-check endpoint confirming the service is up."""
    status_payload = {"message": "PR Reviewer Bot is running!"}
    return status_payload
# NOTE: BaseModel is already imported at the top of this file; the re-import below is redundant but harmless.
from pydantic import BaseModel
# Payload the Gradio UI posts back with the human reviewer's decision.
class HumanInputPayload(BaseModel):
    """Human decision used to resume a paused PR-review workflow."""
    review_id: int
    approval_status: str  # "approved" or "rejected"
    feedback_message: Optional[str] = None  # Optional feedback
    thread_id: str  # LangGraph thread id identifying the paused run
@app.post("/resume-review")
async def resume_review(payload: HumanInputPayload):
    """
    Resume a paused PR-review workflow with a human decision.

    Writes the approval status and feedback into the checkpointed graph state
    identified by payload.thread_id, then streams the graph to completion from
    the interrupt point.

    Returns:
        A JSON message describing the outcome. (Previously, failure paths
        returned bare None, which FastAPI served as HTTP 200 with a null body.)
    """
    print(f"Received human input for Review ID: {payload.review_id}")
    print(f"Status: {payload.approval_status}")
    print(f"Feedback: {payload.feedback_message}")
    print(f"Thread ID: {payload.thread_id}")
    if not global_graph or not global_memory_saver:
        print("Error: Graph or checkpointer not initialized. Check server logs.")
        return {"message": "Graph or checkpointer not initialized; cannot resume."}
    try:
        config = {"configurable": {"thread_id": payload.thread_id}}
        current_state_snapshot = global_graph.get_state(config)
        current_state_values = current_state_snapshot.values
        # Anything other than an explicit "rejected" counts as approval.
        current_state_values["human_approval_status"] = payload.approval_status.lower() != "rejected"
        current_state_values["human_feedback_message"] = payload.feedback_message
        # Write the human decision back into the checkpointed state.
        global_graph.update_state(config, current_state_values)
        print("graph state updated successfuly with human feedback")
        print("RESUMING the graph execution")
        # Streaming with a None input resumes from the interrupt point.
        for step in global_graph.stream(None, config):
            if "__end__" in step:
                break
        final_state_values = global_graph.get_state(config).values
        final_review_id = final_state_values["final_review_id"]
        final_review_url = final_state_values["final_review_url"]
        last_error = final_state_values["last_error"]
        if last_error is not None:
            print(f"Encountered error:{last_error} during graph execution")
            return {"message": f"Workflow resumed but finished with error: {last_error}"}
        print(f"final_review_id:{final_review_id}\nfinal_review_url:{final_review_url}")
        print("graph execution has completed successfully")
    except Exception as e:
        print(f"An error occurred during graph resumption: {e}")
        return {"message": f"An error occurred during graph resumption: {e}"}
    return {"message": "Human input received and processed successfully!"}