Spaces:
Sleeping
Sleeping
Upload 3 files
Browse files- pydantic_ai_agent.py +142 -0
- requirements.txt +0 -0
- streamlit_ui.py +161 -0
pydantic_ai_agent.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations as _annotations
|
| 2 |
+
import datetime
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
import logfire
|
| 6 |
+
import asyncio
|
| 7 |
+
import os
|
| 8 |
+
from supabase import create_client
|
| 9 |
+
from pydantic_ai import Agent, RunContext
|
| 10 |
+
from pydantic_ai.models.gemini import GeminiModel
|
| 11 |
+
from supabase import Client
|
| 12 |
+
from typing import List
|
| 13 |
+
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
| 14 |
+
|
| 15 |
+
# Load environment variables from a local .env file so the API keys and
# Supabase credentials below are available via os.getenv.
load_dotenv()

# Load API keys and initialize Supabase client
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")  # used for the Gemini model and embeddings
GROQ_API_KEY = os.getenv("GROQ_API_KEY")  # NOTE(review): loaded but never used in this file — confirm it is needed

# Supabase client backing the documentation store and the user-query log.
# NOTE(review): created at import time — if SUPABASE_URL1/SUPABASE_SERVICE_KEY1
# are unset this raises on import; confirm that is acceptable for deployment.
supabase: Client = create_client(
    os.getenv("SUPABASE_URL1"),
    os.getenv("SUPABASE_SERVICE_KEY1")
)
# Send telemetry to Logfire unconditionally.
logfire.configure(send_to_logfire='always')
|
| 26 |
+
|
| 27 |
+
@dataclass
class PydanticAIDeps:
    """Dependency bundle injected into each agent tool call via RunContext."""
    # Supabase client used by the tools for documentation lookups and query logging.
    supabase: Client
    # Gemini model instance; carried in the deps bundle, though the visible
    # tools read only `supabase`.  NOTE(review): confirm `gemini` is required here.
    gemini: GeminiModel
|
| 31 |
+
|
| 32 |
+
# Updated system prompt explicitly instructs the agent to use the tools for documentation queries.
system_prompt = """
You are a PineScript expert with direct access to complete documentation through tool functions.
Whenever a user asks for documentation, examples, or strategies for PineScript, you MUST:
1. List all available documentation pages (using the tool "list_documentation_pages").
2. Retrieve detailed, relevant documentation (using the tool "retrieve_relevant_documentation").
Do not generate a direct answer from your internal data.
If no documentation is found, clearly state so.
Then, combine the retrieved information to answer the query accurately.
IMPORTANT: When a user asks "what is your name" or similar, always respond with "PineScript expert"
"""

# Chat model for the agent; embeddings use a separate Google model (see get_embedding).
model = GeminiModel('gemini-1.5-flash', api_key=GOOGLE_API_KEY)
# The RAG agent.  Tools are registered below via @PineScript_expert.tool;
# retries=2 lets pydantic-ai re-prompt the model after validation/tool errors.
PineScript_expert = Agent(
    model,
    system_prompt=system_prompt,
    deps_type=PydanticAIDeps,
    retries=2
)
|
| 51 |
+
|
| 52 |
+
async def get_embedding(text: str) -> List[float]:
    """Embed *text* with Google's embedding-001 model.

    Returns the embedding vector, or a 768-dimension zero vector if the
    embedding call fails, so retrieval degrades gracefully instead of
    crashing the tool pipeline.
    """
    try:
        embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", api_key=GOOGLE_API_KEY)
        # embed_query is a synchronous network call; run it in a worker thread
        # so it does not block the event loop while the agent is streaming.
        vector = await asyncio.to_thread(embeddings.embed_query, text)
        return vector
    except Exception as e:
        # Best-effort fallback: a zero vector matches nothing well but keeps
        # the caller alive.  NOTE(review): assumes embedding-001 emits
        # 768-dim vectors — confirm against the model documentation.
        print(f"Error getting embedding: {e}")
        return [0] * 768
|
| 60 |
+
|
| 61 |
+
@PineScript_expert.tool
async def retrieve_relevant_documentation(ctx: RunContext[PydanticAIDeps], user_query: str) -> str:
    """
    Retrieve relevant documentation chunks based on the query.

    Logs the query to the 'user_queries' table, embeds it, and runs a
    vector-similarity RPC against the documentation store.  Returns the
    matched chunks joined by '---' separators, or a friendly message on
    no match / failure.
    """
    try:
        print("DEBUG: retrieve_relevant_documentation tool function triggered.")
        # Log the user query
        log_result = ctx.deps.supabase.table('user_queries').insert({
            'query': user_query,
            'timestamp': datetime.datetime.now().isoformat()
        }).execute()
        # Some supabase client versions report failures via an .error
        # attribute rather than raising; check defensively.
        error_info = getattr(log_result, "error", None)
        if error_info:
            print("Insert error:", error_info)
        else:
            print("Inserted user query successfully.")
        # Get query embedding and retrieve matched documents
        query_embedding = await get_embedding(user_query)
        # 'match_site_pages' is a server-side RPC — presumably a pgvector
        # similarity search filtered by metadata source.  NOTE(review):
        # confirm its signature matches these parameter names.
        result = ctx.deps.supabase.rpc(
            'match_site_pages',
            {
                'query_embedding': query_embedding,
                'match_count': 20,
                'filter': {'source': 'pydantic_ai_docs'}
            }
        ).execute()
        if not result.data:
            return "I couldn't find any relevant documentation for that query."
        formatted_chunks = []
        for doc in result.data:
            chunks_text = f"""
# {doc['title']}

{doc['content']}
"""
            formatted_chunks.append(chunks_text)
        return "\n\n---\n\n".join(formatted_chunks)
    except Exception as e:
        # Return the error as text so the agent can surface it to the user
        # instead of aborting the run.
        print(f"Error retrieving documentation: {e}")
        return f"An error occurred while retrieving the documentation: {str(e)}"
|
| 102 |
+
|
| 103 |
+
@PineScript_expert.tool
async def list_documentation_pages(ctx: RunContext[PydanticAIDeps]) -> List[str]:
    """
    List all available PineScript documentation pages.

    Returns the sorted, de-duplicated URLs of every stored page, or an
    empty list when nothing is found or the lookup fails.
    """
    try:
        response = (
            ctx.deps.supabase.from_('site_pages')
            .select('url')
            .eq('metadata->>source', 'pydantic_ai_docs')
            .execute()
        )
        rows = response.data
        if not rows:
            return []
        return sorted({row['url'] for row in rows})
    except Exception as e:
        print(f"Error retrieving documentation pages: {e}")
        return []
|
| 120 |
+
|
| 121 |
+
@PineScript_expert.tool
async def get_page_content(ctx: RunContext[PydanticAIDeps], url: str) -> str:
    """
    Retrieve the full content of a specific documentation page.

    Fetches every chunk stored for *url* in chunk order, prefixes the
    page title as a heading, and joins the chunks with blank lines.
    """
    try:
        response = (
            ctx.deps.supabase.from_('site_pages')
            .select('title, content, chunk_number')
            .eq('url', url)
            .eq('metadata->>source', 'pydantic_ai_docs')
            .order('chunk_number')
            .execute()
        )
        rows = response.data
        if not rows:
            return f"No content found for URL: {url}"
        # Titles look like "Page - Section"; keep only the page part.
        page_title = rows[0]['title'].split(' - ')[0]
        formatted_content = [f"# {page_title}\n"]
        formatted_content.extend(row['content'] for row in rows)
        return "\n\n".join(formatted_content)
    except Exception as e:
        print(f"Error retrieving page content: {e}")
        return f"Error retrieving page content: {str(e)}"
|
requirements.txt
ADDED
|
Binary file (210 Bytes). View file
|
|
|
streamlit_ui.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from typing import Literal, TypedDict
|
| 3 |
+
import asyncio
|
| 4 |
+
import os
|
| 5 |
+
import streamlit as st
|
| 6 |
+
import json
|
| 7 |
+
import datetime
|
| 8 |
+
import logfire
|
| 9 |
+
from supabase import Client
|
| 10 |
+
from pydantic_ai.models.gemini import GeminiModel
|
| 11 |
+
from supabase import create_client
|
| 12 |
+
|
| 13 |
+
# Import all the message part classes
|
| 14 |
+
from pydantic_ai.messages import (
|
| 15 |
+
ModelMessage,
|
| 16 |
+
ModelRequest,
|
| 17 |
+
ModelResponse,
|
| 18 |
+
SystemPromptPart,
|
| 19 |
+
UserPromptPart,
|
| 20 |
+
TextPart,
|
| 21 |
+
ToolCallPart,
|
| 22 |
+
ToolReturnPart,
|
| 23 |
+
RetryPromptPart,
|
| 24 |
+
ModelMessagesTypeAdapter
|
| 25 |
+
)
|
| 26 |
+
from pydantic_ai_agent import PineScript_expert, PydanticAIDeps
|
| 27 |
+
|
| 28 |
+
# Load environment variables
|
| 29 |
+
from dotenv import load_dotenv
|
| 30 |
+
# Make .env credentials available via os.getenv before any client is built.
load_dotenv()

GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")  # backs the per-run GeminiModel in run_agent_with_streaming

# Initialize Supabase client
# NOTE(review): this duplicates the client created in pydantic_ai_agent.py —
# consider importing it instead of constructing a second one.
supabase: Client = create_client(
    os.getenv("SUPABASE_URL1"),
    os.getenv("SUPABASE_SERVICE_KEY1")
)

# Configure logfire to suppress warnings (optional)
logfire.configure(send_to_logfire='always')
|
| 42 |
+
|
| 43 |
+
class ChatMessage(TypedDict):
    """Format of messages sent to the browser/API."""
    # Who produced the message; mirrors pydantic-ai's request/response split.
    role: Literal['user', 'model']
    # Timestamp string (ISO-8601 elsewhere in this file).
    timestamp: str
    # Raw text content of the message.
    content: str
    # NOTE(review): this type is not referenced in the visible code —
    # possibly reserved for a future API layer; confirm before removing.
|
| 48 |
+
|
| 49 |
+
def display_message_part(part):
    """
    Render a single message part in the Streamlit chat UI.

    Handles system prompts, user prompts, and plain model text; any
    other part kind (tool calls/returns, retries) is silently skipped.
    """
    kind = part.part_kind
    if kind == 'system-prompt':
        with st.chat_message("system"):
            st.markdown(f"**System**: {part.content}")
    elif kind == 'user-prompt':
        with st.chat_message("user"):
            st.markdown(part.content)
    elif kind == 'text':
        with st.chat_message("assistant"):
            st.markdown(part.content)
|
| 67 |
+
|
| 68 |
+
async def run_agent_with_streaming(user_input: str):
    """
    Run the agent with streaming text for the user_input prompt,
    while maintaining the entire conversation in `st.session_state.messages`.
    """
    # Prepare dependencies for the agent
    deps = PydanticAIDeps(
        supabase=supabase,
        gemini=GeminiModel('gemini-1.5-flash', api_key=GOOGLE_API_KEY)
    )
    # Placeholder that is repeatedly overwritten as chunks arrive, giving
    # the "typing" effect in the assistant chat bubble.
    message_placeholder = st.empty()
    partial_text = ""
    try:
        # Run the agent in a stream.  [:-1] excludes the user message that
        # main() just appended, since run_stream receives it as user_input.
        async with PineScript_expert.run_stream(
            user_input,
            deps=deps,
            message_history=st.session_state.messages[:-1],  # pass entire conversation so far
        ) as result:
            # Render partial text as it arrives
            async for chunk in result.stream_text(delta=True):
                partial_text += chunk
                message_placeholder.markdown(partial_text)
    except Exception as e:
        # Instead of displaying the raw error, display a friendly message.
        message_placeholder.markdown("**Could you please ask again?**")
        st.error("Could you please ask again?")
        # Also, add the friendly message to the conversation history.
        st.session_state.messages.append(
            ModelResponse(parts=[TextPart(content="Could you please ask again?")])
        )
        return

    # Now that the stream is finished, add new messages from this run,
    # dropping user-prompt parts (already in history from main()).
    filtered_messages = [
        msg for msg in result.new_messages()
        if not (hasattr(msg, 'parts') and any(part.part_kind == 'user-prompt' for part in msg.parts))
    ]
    st.session_state.messages.extend(filtered_messages)
    # NOTE(review): new_messages() may already contain the final text
    # response, in which case this append duplicates it in history —
    # verify against the pydantic-ai streaming API.
    st.session_state.messages.append(
        ModelResponse(parts=[TextPart(content=partial_text)])
    )
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def store_user_query(user_query: str):
    """
    Stores the user's query in the 'user_queries' table.

    Best-effort: failures are logged to stdout and never raised, so a
    logging outage cannot break the chat flow.
    """
    try:
        supabase.table('user_queries').insert({
            'query': user_query,
            'timestamp': datetime.datetime.now().isoformat()
        }).execute()
        print("User query stored successfully.")
    except Exception as e:
        print(f"Error storing user query: {e}")
|
| 124 |
+
|
| 125 |
+
async def main():
    """Streamlit entry point: render history, accept input, stream the reply."""
    st.title("PineScript Agentic RAG")
    st.write("Ask any question about PineScript, the hidden truths of its beauty lie within.")

    # Conversation history survives Streamlit reruns via session state.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the conversation so far.  Each entry is a ModelRequest or
    # ModelResponse; its parts decide how each piece is rendered.
    for msg in st.session_state.messages:
        if isinstance(msg, (ModelRequest, ModelResponse)):
            for part in msg.parts:
                display_message_part(part)

    user_input = st.chat_input("What questions do you have about PineScript?")

    if user_input:
        # Persist the query, then record it in the visible conversation.
        store_user_query(user_input)
        st.session_state.messages.append(
            ModelRequest(parts=[UserPromptPart(content=user_input)])
        )

        # Echo the user's prompt immediately.
        with st.chat_message("user"):
            st.markdown(user_input)

        # Stream the agent's answer into the assistant bubble.
        with st.chat_message("assistant"):
            await run_agent_with_streaming(user_input)
|
| 159 |
+
|
| 160 |
+
if __name__ == "__main__":
    # Streamlit re-executes this script on every interaction; main() is
    # async because the agent streaming API is async, so drive it here.
    asyncio.run(main())
|