Spaces:
Runtime error
Runtime error
New agentic features
Browse files- agent.py +57 -2
- app.py +64 -29
- drive_tools.py +115 -35
- google_auth.py +0 -60
- google_auth_flow.py +90 -0
- oauth_callback.py +41 -0
- supabase_auth.py +52 -0
agent.py
CHANGED
|
@@ -71,7 +71,6 @@ llm_with_tools = llm.bind_tools(tools)
|
|
| 71 |
class State(TypedDict):
|
| 72 |
messages: Annotated[list,add_messages]
|
| 73 |
|
| 74 |
-
graph_builder=StateGraph(State)
|
| 75 |
|
| 76 |
# ==================== NODES =======================
|
| 77 |
|
|
@@ -93,9 +92,65 @@ def chatbot(state:State):
|
|
| 93 |
])
|
| 94 |
return {"messages":[response]}
|
| 95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
# ==================== GRAPH =======================
|
| 97 |
|
| 98 |
# Adding Node
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
graph_builder.add_node("chatbot", chatbot)
|
| 100 |
|
| 101 |
tool_node = ToolNode(tools=tools)
|
|
@@ -116,7 +171,7 @@ graph_builder.add_conditional_edges(
|
|
| 116 |
|
| 117 |
graph_builder.add_edge("tools","chatbot")
|
| 118 |
|
| 119 |
-
graph=graph_builder.compile()
|
| 120 |
|
| 121 |
# ==================== ENTRY FUNCTION =======================
|
| 122 |
|
|
|
|
| 71 |
class State(TypedDict):
|
| 72 |
messages: Annotated[list,add_messages]
|
| 73 |
|
|
|
|
| 74 |
|
| 75 |
# ==================== NODES =======================
|
| 76 |
|
|
|
|
| 92 |
])
|
| 93 |
return {"messages":[response]}
|
| 94 |
|
| 95 |
+
def handle_tools(state: State):
|
| 96 |
+
"""
|
| 97 |
+
Custom tool-execution node that intercepts AUTH_REQUIRED:: sentinels
|
| 98 |
+
from the Drive tool and surfaces them via LangGraph interrupt instead
|
| 99 |
+
of letting the agent loop silently.
|
| 100 |
+
"""
|
| 101 |
+
last_message: AIMessage = state["messages"][-1]
|
| 102 |
+
|
| 103 |
+
results = []
|
| 104 |
+
for tool_call in last_message.tool_calls:
|
| 105 |
+
# ── Run the tool ──────────────────────────────────────────────────────
|
| 106 |
+
matched_tool = next(
|
| 107 |
+
(t for t in tools if t.name == tool_call["name"]), None
|
| 108 |
+
)
|
| 109 |
+
if matched_tool is None:
|
| 110 |
+
result_content = f"Unknown tool: {tool_call['name']}"
|
| 111 |
+
else:
|
| 112 |
+
result_content = matched_tool.invoke(tool_call["args"])
|
| 113 |
+
|
| 114 |
+
# ── Auth-gate check ───────────────────────────────────────────────────
|
| 115 |
+
if isinstance(result_content, str) and result_content.startswith(AUTH_REQUIRED_PREFIX):
|
| 116 |
+
# Extract the OAuth URL (everything after the sentinel prefix, first line)
|
| 117 |
+
first_line = result_content.split("\n")[0]
|
| 118 |
+
auth_url = first_line.removeprefix(AUTH_REQUIRED_PREFIX).strip()
|
| 119 |
+
|
| 120 |
+
# Interrupt the graph and surface the URL to the front-end
|
| 121 |
+
# The interrupt value is returned to whoever is streaming the graph.
|
| 122 |
+
interrupt({
|
| 123 |
+
"type": "auth_required",
|
| 124 |
+
"auth_url": auth_url,
|
| 125 |
+
"message": (
|
| 126 |
+
"🔐 Google Drive access is required. "
|
| 127 |
+
"Please authenticate by visiting the link below, then retry your request.\n\n"
|
| 128 |
+
f"👉 {auth_url}"
|
| 129 |
+
),
|
| 130 |
+
})
|
| 131 |
+
# After the user resumes (post-OAuth), return a helpful ToolMessage
|
| 132 |
+
result_content = (
|
| 133 |
+
"Authentication flow initiated. Once you have completed Google sign-in, "
|
| 134 |
+
"please repeat your Drive request."
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
results.append(
|
| 138 |
+
ToolMessage(
|
| 139 |
+
content=result_content,
|
| 140 |
+
tool_call_id=tool_call["id"],
|
| 141 |
+
)
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
return {"messages": results}
|
| 145 |
+
|
| 146 |
+
|
| 147 |
# ==================== GRAPH =======================
|
| 148 |
|
| 149 |
# Adding Node
|
| 150 |
+
memory = MemorySaver()
|
| 151 |
+
|
| 152 |
+
graph_builder=StateGraph(State)
|
| 153 |
+
|
| 154 |
graph_builder.add_node("chatbot", chatbot)
|
| 155 |
|
| 156 |
tool_node = ToolNode(tools=tools)
|
|
|
|
| 171 |
|
| 172 |
graph_builder.add_edge("tools","chatbot")
|
| 173 |
|
| 174 |
+
graph=graph_builder.compile(checkpointer=memory)
|
| 175 |
|
| 176 |
# ==================== ENTRY FUNCTION =======================
|
| 177 |
|
app.py
CHANGED
|
@@ -1,33 +1,68 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
from
|
| 3 |
-
from
|
| 4 |
-
from agent import run_agent
|
| 5 |
-
import json
|
| 6 |
|
| 7 |
-
|
|
|
|
| 8 |
|
| 9 |
-
@app.get("/login")
|
| 10 |
-
def login():
|
| 11 |
-
auth_url, state = get_auth_url()
|
| 12 |
-
return {"auth_url": auth_url}
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
)
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
from agent import graph
|
| 3 |
+
from oauth_callback import handle_oauth_callback
|
|
|
|
|
|
|
| 4 |
|
| 5 |
+
# ── Persistent thread so MemorySaver keeps conversation context ───────────────
|
| 6 |
+
THREAD_CONFIG = {"configurable": {"thread_id": "default-thread"}}
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
def chat(user_message: str, history: list) -> str:
|
| 10 |
+
"""Called by the Gradio ChatInterface on each user message."""
|
| 11 |
+
events = graph.stream(
|
| 12 |
+
{"messages": [{"role": "user", "content": user_message}]},
|
| 13 |
+
config=THREAD_CONFIG,
|
| 14 |
+
stream_mode="values",
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
last_ai_text = ""
|
| 18 |
+
for event in events:
|
| 19 |
+
# Check for an interrupt (auth required)
|
| 20 |
+
if "__interrupt__" in event:
|
| 21 |
+
interrupt_val = event["__interrupt__"][0].value
|
| 22 |
+
if interrupt_val.get("type") == "auth_required":
|
| 23 |
+
return interrupt_val["message"]
|
| 24 |
+
|
| 25 |
+
msgs = event.get("messages", [])
|
| 26 |
+
for msg in reversed(msgs):
|
| 27 |
+
if hasattr(msg, "content") and msg.type == "ai" and not msg.tool_calls:
|
| 28 |
+
last_ai_text = msg.content
|
| 29 |
+
break
|
| 30 |
+
|
| 31 |
+
return last_ai_text or "Done."
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# ── OAuth callback endpoint ───────────────────────────────────────────────────
|
| 35 |
+
|
| 36 |
+
def oauth_callback_page(request: gr.Request) -> str:
|
| 37 |
+
"""
|
| 38 |
+
Gradio page that Google redirects to after the user grants consent.
|
| 39 |
+
Mount at /oauth/callback in your Space.
|
| 40 |
+
"""
|
| 41 |
+
params = dict(request.query_params)
|
| 42 |
+
code = params.get("code", "")
|
| 43 |
+
state = params.get("state", "")
|
| 44 |
+
result = handle_oauth_callback(code, state)
|
| 45 |
+
if result["success"]:
|
| 46 |
+
return f"<h2>{result['message']}</h2><p>You can close this tab and return to the chat.</p>"
|
| 47 |
+
return f"<h2>❌ Authentication failed</h2><p>{result['message']}</p>"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# ── Gradio UI ─────────────────────────────────────────────────────────────────
|
| 51 |
+
|
| 52 |
+
with gr.Blocks(title="AI Agent") as demo:
|
| 53 |
+
gr.Markdown("## 🤖 AI Agent | Email · Google Drive")
|
| 54 |
+
|
| 55 |
+
with gr.Tab("Chat"):
|
| 56 |
+
gr.ChatInterface(fn=chat)
|
| 57 |
+
|
| 58 |
+
# Hidden page — Google redirects here after OAuth
|
| 59 |
+
with gr.Tab("OAuth Callback", visible=False) as callback_tab:
|
| 60 |
+
callback_output = gr.HTML()
|
| 61 |
+
|
| 62 |
+
# Route /oauth/callback → the handler above
|
| 63 |
+
demo.load(fn=None) # placeholder; real routing done via gr.mount_gradio_app or FastAPI
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# ── For local dev, run directly ───────────────────────────────────────────────
|
| 67 |
+
if __name__ == "__main__":
|
| 68 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|
drive_tools.py
CHANGED
|
@@ -1,49 +1,129 @@
|
|
| 1 |
-
from langchain_core.tools import tool
|
| 2 |
-
from googleapiclient.discovery import build
|
| 3 |
-
from googleapiclient.http import MediaIoBaseDownload
|
| 4 |
-
from google_auth import dict_to_creds
|
| 5 |
import io
|
| 6 |
import os
|
| 7 |
import json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
@tool
|
| 12 |
-
def search_and_download_doc_tool(file_name: str) -> str:
|
| 13 |
-
"""
|
| 14 |
-
Searches Google Drive and downloads a document by name.
|
| 15 |
-
"""
|
| 16 |
-
|
| 17 |
-
if not os.path.exists(TOKEN_STORE):
|
| 18 |
-
return "User not authenticated. Please login first."
|
| 19 |
-
|
| 20 |
-
with open(TOKEN_STORE, "r") as f:
|
| 21 |
-
token_data = json.load(f)
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
|
|
|
|
|
|
| 26 |
results = service.files().list(
|
| 27 |
-
q=f"
|
| 28 |
-
|
|
|
|
| 29 |
).execute()
|
|
|
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
downloader = MediaIoBaseDownload(fh, request)
|
| 44 |
-
|
| 45 |
done = False
|
| 46 |
while not done:
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import io
|
| 2 |
import os
|
| 3 |
import json
|
| 4 |
+
from langchain_core.tools import tool
|
| 5 |
+
from googleapiclient.discovery import build
|
| 6 |
+
from googleapiclient.http import MediaIoBaseDownload
|
| 7 |
+
from googleapiclient.errors import HttpError
|
| 8 |
+
|
| 9 |
+
from supabase_auth import get_token, save_token
|
| 10 |
+
from google_auth_flow import credentials_from_token_dict, get_auth_url
|
| 11 |
|
| 12 |
+
AUTH_REQUIRED_PREFIX = "AUTH_REQUIRED::"
|
| 13 |
+
DOWNLOAD_DIR = os.getenv("DOWNLOAD_DIR", "/tmp/drive_downloads")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
+
def _drive_service(user_email: str):
|
| 16 |
+
"""Returns an auth Drive service"""
|
| 17 |
+
token_dict = get_token(user_email)
|
| 18 |
+
if not token_dict:
|
| 19 |
+
return None
|
| 20 |
+
creds = credentials_from_token_dict(token_dict)
|
| 21 |
+
# Persist refreshed token back to Supabase
|
| 22 |
+
save_token(user_email, {
|
| 23 |
+
"token": creds.token,
|
| 24 |
+
"refresh_token": creds.refresh_token,
|
| 25 |
+
"token_uri": creds.token_uri,
|
| 26 |
+
"client_id": creds.client_id,
|
| 27 |
+
"client_secret": creds.client_secret,
|
| 28 |
+
"scopes": list(creds.scopes or []),
|
| 29 |
+
"expiry": creds.expiry.isoformat() if creds.expiry else None,
|
| 30 |
+
})
|
| 31 |
+
return build("drive", "v3", credentials=creds)
|
| 32 |
|
| 33 |
+
def _search_files(service, query: str, max_results: int = 5) -> list[dict]:
|
| 34 |
+
"""Full-text search across Drive files."""
|
| 35 |
results = service.files().list(
|
| 36 |
+
q=f"fullText contains '{query}' and trashed=false",
|
| 37 |
+
pageSize=max_results,
|
| 38 |
+
fields="files(id, name, mimeType, webViewLink, size)",
|
| 39 |
).execute()
|
| 40 |
+
return results.get("files", [])
|
| 41 |
|
| 42 |
+
def _download_file(service, file_id: str, file_name: str, mime_type: str) -> str:
|
| 43 |
+
"""Downloads a file and returns local path."""
|
| 44 |
+
os.makedirs(DOWNLOAD_DIR, exist_ok=True)
|
| 45 |
+
local_path = os.path.join(DOWNLOAD_DIR, file_name)
|
| 46 |
+
|
| 47 |
+
# Google Workspace docs must be exported; regular files use get_media
|
| 48 |
+
export_map = {
|
| 49 |
+
"application/vnd.google-apps.document":
|
| 50 |
+
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
| 51 |
+
"application/vnd.google-apps.spreadsheet":
|
| 52 |
+
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
| 53 |
+
"application/vnd.google-apps.presentation":
|
| 54 |
+
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
| 55 |
+
}
|
| 56 |
|
| 57 |
+
fh = io.BytesIO()
|
| 58 |
+
if mime_type in export_map:
|
| 59 |
+
export_mime = export_map[mime_type]
|
| 60 |
+
request = service.files().export_media(fileId=file_id, mimeType=export_mime)
|
| 61 |
+
ext_map = {
|
| 62 |
+
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
|
| 63 |
+
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
|
| 64 |
+
"application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
|
| 65 |
+
}
|
| 66 |
+
local_path += ext_map.get(export_mime, "")
|
| 67 |
+
else:
|
| 68 |
+
request = service.files().get_media(fileId=file_id)
|
| 69 |
+
|
| 70 |
downloader = MediaIoBaseDownload(fh, request)
|
|
|
|
| 71 |
done = False
|
| 72 |
while not done:
|
| 73 |
+
_, done = downloader.next_chunk()
|
| 74 |
+
|
| 75 |
+
with open(local_path, "wb") as f:
|
| 76 |
+
f.write(fh.getvalue())
|
| 77 |
+
|
| 78 |
+
return local_path
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@tool
|
| 83 |
+
def search_and_download_doc_tool(user_email: str, query: str) -> str:
|
| 84 |
+
"""
|
| 85 |
+
Searches Google Drive and downloads a document by name.
|
| 86 |
+
"""
|
| 87 |
|
| 88 |
+
service = _drive_service(user_email)
|
| 89 |
+
if service is None:
|
| 90 |
+
auth_url = get_auth_url(state=user_email) # state carries email through OAuth
|
| 91 |
+
return (
|
| 92 |
+
f"{AUTH_REQUIRED_PREFIX}{auth_url}\n"
|
| 93 |
+
f"User {user_email} is not authenticated with Google Drive. "
|
| 94 |
+
f"They must visit the URL above to grant access."
|
| 95 |
+
)
|
| 96 |
+
try:
|
| 97 |
+
files = _search_files(service, query)
|
| 98 |
+
except HttpError as e:
|
| 99 |
+
return f"Drive search failed: {e}"
|
| 100 |
+
|
| 101 |
+
if not files:
|
| 102 |
+
return f"No files found on Google Drive matching '{query}'."
|
| 103 |
+
|
| 104 |
+
# Pick the first (most relevant) result
|
| 105 |
+
best = files[0]
|
| 106 |
+
file_id = best["id"]
|
| 107 |
+
file_name = best["name"]
|
| 108 |
+
mime_type = best["mimeType"]
|
| 109 |
+
view_link = best.get("webViewLink", "N/A")
|
| 110 |
+
|
| 111 |
+
other_matches = [f["name"] for f in files[1:]]
|
| 112 |
+
other_str = (
|
| 113 |
+
f"\n\nOther matches: {', '.join(other_matches)}" if other_matches else ""
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
# ── 3. Download ────────────────────────────────────────────────────────────
|
| 117 |
+
try:
|
| 118 |
+
local_path = _download_file(service, file_id, file_name, mime_type)
|
| 119 |
+
return (
|
| 120 |
+
f"✅ Found and downloaded '{file_name}'.\n"
|
| 121 |
+
f"Saved to: {local_path}\n"
|
| 122 |
+
f"View online: {view_link}"
|
| 123 |
+
f"{other_str}"
|
| 124 |
+
)
|
| 125 |
+
except HttpError as e:
|
| 126 |
+
return (
|
| 127 |
+
f"Found '{file_name}' on Drive but download failed: {e}\n"
|
| 128 |
+
f"View online: {view_link}"
|
| 129 |
+
)
|
google_auth.py
DELETED
|
@@ -1,60 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import json
|
| 3 |
-
from google_auth_oauthlib.flow import Flow
|
| 4 |
-
from google.oauth2.credentials import Credentials
|
| 5 |
-
|
| 6 |
-
SCOPES = ['https://www.googleapis.com/auth/drive.readonly']
|
| 7 |
-
|
| 8 |
-
REDIRECT_URI = os.getenv("REDIRECT_URI")
|
| 9 |
-
|
| 10 |
-
def create_flow():
|
| 11 |
-
client_config = {
|
| 12 |
-
"web": {
|
| 13 |
-
"client_id": os.getenv("GOOGLE_CLIENT_ID"),
|
| 14 |
-
"client_secret": os.getenv("GOOGLE_CLIENT_SECRET"),
|
| 15 |
-
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
| 16 |
-
"token_uri": "https://oauth2.googleapis.com/token"
|
| 17 |
-
}
|
| 18 |
-
}
|
| 19 |
-
|
| 20 |
-
flow = Flow.from_client_config(
|
| 21 |
-
client_config,
|
| 22 |
-
scopes=SCOPES,
|
| 23 |
-
redirect_uri=REDIRECT_URI
|
| 24 |
-
)
|
| 25 |
-
|
| 26 |
-
return flow
|
| 27 |
-
|
| 28 |
-
def get_auth_url():
|
| 29 |
-
flow = create_flow()
|
| 30 |
-
auth_url, state = flow.authorization_url(
|
| 31 |
-
access_type='offline',
|
| 32 |
-
include_granted_scopes='true'
|
| 33 |
-
)
|
| 34 |
-
return auth_url, state
|
| 35 |
-
|
| 36 |
-
def fetch_token(code):
|
| 37 |
-
flow = create_flow()
|
| 38 |
-
flow.fetch_token(code=code)
|
| 39 |
-
creds = flow.credentials
|
| 40 |
-
return creds_to_dict(creds)
|
| 41 |
-
|
| 42 |
-
def creds_to_dict(creds):
|
| 43 |
-
return {
|
| 44 |
-
"token": creds.token,
|
| 45 |
-
"refresh_token": creds.refresh_token,
|
| 46 |
-
"token_uri": creds.token_uri,
|
| 47 |
-
"client_id": creds.client_id,
|
| 48 |
-
"client_secret": creds.client_secret,
|
| 49 |
-
"scopes": creds.scopes
|
| 50 |
-
}
|
| 51 |
-
|
| 52 |
-
def dict_to_creds(data):
|
| 53 |
-
return Credentials(
|
| 54 |
-
token=data["token"],
|
| 55 |
-
refresh_token=data["refresh_token"],
|
| 56 |
-
token_uri=data["token_uri"],
|
| 57 |
-
client_id=data["client_id"],
|
| 58 |
-
client_secret=data["client_secret"],
|
| 59 |
-
scopes=data["scopes"]
|
| 60 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
google_auth_flow.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from google_auth_oauthlib.flow import Flow
|
| 4 |
+
from google.oauth2.credentials import Credentials
|
| 5 |
+
from google.auth.transport.requests import Request
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
|
| 8 |
+
SCOPES = [
|
| 9 |
+
"https://www.googleapis.com/auth/drive.readonly",
|
| 10 |
+
"https://www.googleapis.com/auth/userinfo.email",
|
| 11 |
+
"openid",
|
| 12 |
+
]
|
| 13 |
+
|
| 14 |
+
load_dotenv()
|
| 15 |
+
|
| 16 |
+
CLIENT_ID = os.getenv("GOOGLE_CLIENT_ID", "")
|
| 17 |
+
CLIENT_SECRET = os.getenv("GOOGLE_CLIENT_SECRET", "")
|
| 18 |
+
REDIRECT_URI = os.getenv("GOOGLE_REDIRECT_URI", "")
|
| 19 |
+
|
| 20 |
+
def _client_config() -> dict:
|
| 21 |
+
"""Builds the client config dict that google_auth_oauthlib expects."""
|
| 22 |
+
return {
|
| 23 |
+
"web": {
|
| 24 |
+
"client_id": CLIENT_ID,
|
| 25 |
+
"client_secret": CLIENT_SECRET,
|
| 26 |
+
"redirect_uris": [REDIRECT_URI],
|
| 27 |
+
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
| 28 |
+
"token_uri": "https://oauth2.googleapis.com/token",
|
| 29 |
+
}
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
def get_auth_url(state: str | None = None) -> str:
|
| 33 |
+
"""
|
| 34 |
+
Returns the Google OAuth consent-screen URL to redirect the user to.
|
| 35 |
+
`state` can carry any context you want back in the callback (e.g. user_email).
|
| 36 |
+
"""
|
| 37 |
+
flow = Flow.from_client_config(_client_config(), scopes=SCOPES)
|
| 38 |
+
flow.redirect_uri = REDIRECT_URI
|
| 39 |
+
auth_url, _ = flow.authorization_url(
|
| 40 |
+
access_type="offline", # get refresh_token
|
| 41 |
+
include_granted_scopes="true",
|
| 42 |
+
prompt="consent", # force refresh_token every time during dev
|
| 43 |
+
state=state or "",
|
| 44 |
+
)
|
| 45 |
+
return auth_url
|
| 46 |
+
|
| 47 |
+
def exchange_code_for_token(code: str) -> dict:
|
| 48 |
+
"""
|
| 49 |
+
Exchanges an authorization code (from the OAuth callback) for credentials.
|
| 50 |
+
Returns a JSON-serialisable token dict.
|
| 51 |
+
"""
|
| 52 |
+
flow = Flow.from_client_config(_client_config(), scopes=SCOPES)
|
| 53 |
+
flow.redirect_uri = REDIRECT_URI
|
| 54 |
+
flow.fetch_token(code=code)
|
| 55 |
+
creds = flow.credentials
|
| 56 |
+
return _creds_to_dict(creds)
|
| 57 |
+
|
| 58 |
+
def credentials_from_token_dict(token_dict: dict) -> Credentials:
|
| 59 |
+
"""
|
| 60 |
+
Re-hydrates a Credentials object from a stored token dict,
|
| 61 |
+
refreshing automatically if the access token is expired.
|
| 62 |
+
"""
|
| 63 |
+
creds = Credentials(
|
| 64 |
+
token=token_dict.get("token"),
|
| 65 |
+
refresh_token=token_dict.get("refresh_token"),
|
| 66 |
+
token_uri="https://oauth2.googleapis.com/token",
|
| 67 |
+
client_id=CLIENT_ID,
|
| 68 |
+
client_secret=CLIENT_SECRET,
|
| 69 |
+
scopes=token_dict.get("scopes", SCOPES),
|
| 70 |
+
)
|
| 71 |
+
if creds.expired and creds.refresh_token:
|
| 72 |
+
creds.refresh(Request())
|
| 73 |
+
return creds
|
| 74 |
+
|
| 75 |
+
def _creds_to_dict(creds: Credentials) -> dict:
|
| 76 |
+
return {
|
| 77 |
+
"token": creds.token,
|
| 78 |
+
"refresh_token": creds.refresh_token,
|
| 79 |
+
"token_uri": creds.token_uri,
|
| 80 |
+
"client_id": creds.client_id,
|
| 81 |
+
"client_secret": creds.client_secret,
|
| 82 |
+
"scopes": list(creds.scopes or SCOPES),
|
| 83 |
+
"expiry": creds.expiry.isoformat() if creds.expiry else None,
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
|
oauth_callback.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Handles the google OAuth 2.0 redirect callback"""
|
| 2 |
+
|
| 3 |
+
from google_auth_flow import exchange_code_for_token
|
| 4 |
+
from supabase_auth import save_token
|
| 5 |
+
|
| 6 |
+
def handle_oauth_callback(code: str, state: str) -> dict:
|
| 7 |
+
"""
|
| 8 |
+
Exchanges the OAuth authorization code for tokens and persists them.
|
| 9 |
+
|
| 10 |
+
Args:
|
| 11 |
+
code: The `code` query parameter from the callback URL.
|
| 12 |
+
state: The `state` parameter — we use it to carry the user's email.
|
| 13 |
+
|
| 14 |
+
Returns:
|
| 15 |
+
A dict with keys:
|
| 16 |
+
- success (bool)
|
| 17 |
+
- user_email (str)
|
| 18 |
+
- message (str)
|
| 19 |
+
"""
|
| 20 |
+
user_email = state # we set state=user_email when building the auth URL
|
| 21 |
+
|
| 22 |
+
if not code:
|
| 23 |
+
return {"success": False, "user_email": user_email, "message": "No authorization code received."}
|
| 24 |
+
if not user_email:
|
| 25 |
+
return {"success": False, "user_email": "", "message": "No user email in OAuth state parameter."}
|
| 26 |
+
|
| 27 |
+
try:
|
| 28 |
+
token_dict = exchange_code_for_token(code)
|
| 29 |
+
save_token(user_email, token_dict)
|
| 30 |
+
return {
|
| 31 |
+
"success": True,
|
| 32 |
+
"user_email": user_email,
|
| 33 |
+
"message": f"✅ Google Drive access granted and token saved for {user_email}. You can now search your Drive.",
|
| 34 |
+
}
|
| 35 |
+
except Exception as e:
|
| 36 |
+
return {
|
| 37 |
+
"success": False,
|
| 38 |
+
"user_email": user_email,
|
| 39 |
+
"message": f"OAuth token exchange failed: {str(e)}",
|
| 40 |
+
}
|
| 41 |
+
|
supabase_auth.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Handles storing and retrieving Google OAuth tokens from Supabase."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
from supabase import create_client, Client
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
|
| 8 |
+
load_dotenv()
|
| 9 |
+
|
| 10 |
+
SUPABASE_URL: str = os.getenv("SUPABASE_URL", "")
|
| 11 |
+
SUPABASE_KEY: str = os.getenv("SUPABASE_SERVICE_KEY", "")
|
| 12 |
+
|
| 13 |
+
def _get_client() -> Client:
|
| 14 |
+
if not SUPABASE_URL or not SUPABASE_KEY:
|
| 15 |
+
raise EnvironmentError("SUPABASE_URL and SUPABASE_SERVICE_KEY must be set in .env")
|
| 16 |
+
return create_client(SUPABASE_URL, SUPABASE_KEY)
|
| 17 |
+
|
| 18 |
+
def get_token(user_email: str) -> dict | None:
|
| 19 |
+
"""
|
| 20 |
+
Returns the stored Google OAuth token dict for the user, or None if not found.
|
| 21 |
+
"""
|
| 22 |
+
client = _get_client()
|
| 23 |
+
result = (
|
| 24 |
+
client.table(TABLE)
|
| 25 |
+
.select("token_json")
|
| 26 |
+
.eq("user_email", user_email)
|
| 27 |
+
.maybe_single()
|
| 28 |
+
.execute()
|
| 29 |
+
)
|
| 30 |
+
if result.data:
|
| 31 |
+
return json.loads(result.data["token_json"])
|
| 32 |
+
return None
|
| 33 |
+
|
| 34 |
+
def save_token(user_email: str, token_dict: dict) -> None:
|
| 35 |
+
"""
|
| 36 |
+
Upserts (insert or update) the Google OAuth token for the user.
|
| 37 |
+
"""
|
| 38 |
+
client = _get_client()
|
| 39 |
+
client.table(TABLE).upsert(
|
| 40 |
+
{
|
| 41 |
+
"user_email": user_email,
|
| 42 |
+
"token_json": json.dumps(token_dict),
|
| 43 |
+
}
|
| 44 |
+
).execute()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def delete_token(user_email: str) -> None:
|
| 48 |
+
"""
|
| 49 |
+
Removes the stored token (e.g. on logout or revocation).
|
| 50 |
+
"""
|
| 51 |
+
client = _get_client()
|
| 52 |
+
client.table(TABLE).delete().eq("user_email", user_email).execute()
|