# gridguide / app.py
# vprzybylo
# fix: deprecated use_container_width
# 9bb8a5e
import logging
import os
import sys
import tempfile
from pathlib import Path
from typing import Annotated, TypedDict
import openai
import requests
import streamlit as st
from audio_recorder_streamlit import audio_recorder
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_core.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
)
from langchain_core.tools import Tool
from langchain_openai import ChatOpenAI
from langgraph.graph.message import add_messages
from PIL import Image
sys.path.append(str(Path(__file__).parent))
from dotenv import load_dotenv
# Now import from app.src
from app.src.embedding.model import EmbeddingModel
from app.src.rag.chain import RAGChain
from app.src.rag.document_loader import GridCodeLoader
from app.src.rag.vectorstore import VectorStore
# Load the .env file that sits in the same directory as this module.
load_dotenv(Path(__file__).with_name(".env"))

# Logger named after this module.
logger = logging.getLogger(__name__)
def get_secrets():
    """Collect the app's API settings from environment variables.

    Streamlit secrets are intentionally not consulted. The two API keys are
    ``None`` when unset, ``LANGCHAIN_PROJECT`` falls back to ``"GridGuide"``
    and ``LANGCHAIN_TRACING_V2`` falls back to ``"true"``.

    Returns:
        A dict mapping each setting name to its value (or fallback).
    """
    # Setting name -> fallback used when the variable is not set.
    fallbacks = {
        "OPENAI_API_KEY": None,
        "LANGCHAIN_API_KEY": None,
        "LANGCHAIN_PROJECT": "GridGuide",
        "LANGCHAIN_TRACING_V2": "true",
    }
    return {name: os.getenv(name, fallback) for name, fallback in fallbacks.items()}
# Mirror the resolved secrets into the process environment, skipping any
# entry that is unset or empty.
secrets = get_secrets()
os.environ.update({name: secret for name, secret in secrets.items() if secret})

# Without an OpenAI key the app cannot proceed: show an error and halt the
# script run.
if not os.environ.get("OPENAI_API_KEY"):
    st.error("OpenAI API key not found. Please check your .env file.")
    st.stop()
class WeatherTool:
    """Fetch a weather forecast for a US ZIP code.

    The ZIP code is resolved to coordinates through the Zippopotam API and
    the coordinates are then used to query the weather.gov API.
    """

    def __init__(self, timeout=10):
        """Set the weather.gov base URL, request headers and HTTP timeout.

        Args:
            timeout: Seconds to wait on each HTTP request before giving up.
        """
        self.base_url = "https://api.weather.gov"
        self.headers = {
            "User-Agent": "(Grid Code Assistant, contact@example.com)",
            "Accept": "application/json",
        }
        self.timeout = timeout

    @staticmethod
    def _normalize_zip(zipcode):
        """Return *zipcode* as a 5-digit ASCII string, or None if it is not one."""
        text = str(zipcode).strip()
        if len(text) == 5 and text.isascii() and text.isdigit():
            return text
        return None

    def _get_json(self, url, headers=None):
        """GET *url* and return the decoded JSON body, or None on any failure."""
        try:
            response = requests.get(url, headers=headers, timeout=self.timeout)
            if response.status_code != 200:
                return None
            return response.json()
        except (requests.RequestException, ValueError) as exc:
            # Connection problems, timeouts and undecodable bodies all mean
            # "no data"; callers turn None into their own error message.
            logger.warning("Request to %s failed: %s", url, exc)
            return None

    def get_coordinates_from_zip(self, zipcode):
        """Resolve a ZIP code to coordinates and a place name.

        Args:
            zipcode: A 5-digit US ZIP code (surrounding whitespace is ignored).

        Returns:
            A dict with ``lat``, ``lon``, ``place`` and ``state`` taken from
            the first place in the lookup response, or ``None`` when the ZIP
            code is malformed or the lookup fails.
        """
        # Validate before building the URL so arbitrary text never ends up
        # in the request path.
        zip_text = self._normalize_zip(zipcode)
        if zip_text is None:
            return None

        data = self._get_json(f"https://api.zippopotam.us/us/{zip_text}")
        try:
            place = data["places"][0]
            return {
                "lat": place["latitude"],
                "lon": place["longitude"],
                "place": place["place name"],
                "state": place["state"],
            }
        except (LookupError, TypeError):
            # Failed request (data is None) or a response without the
            # expected fields.
            return None

    def run(self, zipcode):
        """Return current conditions and a short forecast for *zipcode*.

        On success the result is also stored in
        ``st.session_state.weather_data``.

        Returns:
            A dict with ``type`` (``"weather"``), ``location``, ``current``
            (first forecast period) and ``forecast`` (the following three
            periods), or a dict with a single ``error`` message.
        """
        coords = self.get_coordinates_from_zip(zipcode)
        if not coords:
            return {"error": "Invalid ZIP code or unable to get coordinates."}

        point_url = f"{self.base_url}/points/{coords['lat']},{coords['lon']}"
        grid_data = self._get_json(point_url, headers=self.headers)
        try:
            forecast_url = grid_data["properties"]["forecast"]
        except (LookupError, TypeError):
            return {"error": "Unable to fetch weather data."}

        forecast = self._get_json(forecast_url, headers=self.headers)
        try:
            periods = forecast["properties"]["periods"]
        except (LookupError, TypeError):
            return {"error": "Unable to fetch forecast data."}
        if not periods:
            # Nothing to show as "current"; indexing would raise otherwise.
            return {"error": "Unable to fetch forecast data."}

        weather_data = {
            "type": "weather",
            "location": f"{coords['place']}, {coords['state']}",
            "current": periods[0],
            "forecast": periods[1:4],
        }
        # Save to session state
        st.session_state.weather_data = weather_data
        return weather_data
def initialize_rag():
    """Return the RAG chain for the Grid Code, building it on first use.

    A chain already stored under ``st.session_state.rag_chain`` is returned
    as-is. Otherwise ``Grid_Code.pdf`` is searched for in a fixed list of
    locations, loaded and split with ``GridCodeLoader``, indexed into a
    vector store, wrapped in a ``RAGChain`` and stored in session state.

    Raises:
        FileNotFoundError: If the PDF exists in none of the candidate paths.
    """
    if "rag_chain" in st.session_state:
        logger.info("Using cached RAG chain from session state")
        return st.session_state.rag_chain

    # Candidate locations for the PDF, checked in order.
    possible_paths = [
        "Grid_Code.pdf",  # Base directory (local and Docker)
        "/app/Grid_Code.pdf",  # Docker container path
        Path(__file__).parent / "Grid_Code.pdf",  # Absolute path
    ]

    for candidate in map(Path, possible_paths):
        logger.info(f"Checking path: {candidate}")
        if candidate.exists():
            data_path = str(candidate)
            logger.info(f"Found PDF at: {data_path}")
            break
    else:
        # The loop finished without finding the file anywhere.
        raise FileNotFoundError(
            f"PDF not found in any of these locations: {possible_paths}"
        )

    with st.spinner("Loading Grid Code documents..."):
        documents = GridCodeLoader(data_path, pages=25).load_and_split()
        logger.info(f"Loaded {len(documents)} document chunks")

    with st.spinner("Creating vector store..."):
        store_builder = VectorStore(EmbeddingModel())
        vectorstore = store_builder.create_vectorstore(documents)
        logger.info("Vector store created successfully")

    # Keep the chain in session state so later calls can reuse it.
    chain = RAGChain(vectorstore)
    st.session_state.rag_chain = chain
    return chain
class RAGTool:
    """Adapter that exposes a RAG chain through a ``run`` method."""

    def __init__(self, rag_chain):
        """Keep a reference to the chain used to answer questions."""
        self.rag_chain = rag_chain

    def run(self, question: str) -> str:
        """Invoke the chain with *question* and return its ``"answer"`` entry."""
        return self.rag_chain.invoke(question)["answer"]
class AgentState(TypedDict):
    """Typed dictionary describing the agent's state."""

    # List of messages, annotated with ``add_messages``.
    messages: Annotated[list, add_messages]
def create_agent_workflow(rag_chain, weather_tool):
    """Build an agent executor with a Grid Code tool and a weather tool.

    Args:
        rag_chain: Object whose ``invoke(question)`` result has an
            ``"answer"`` entry.
        weather_tool: Object whose ``run(zipcode)`` returns weather data.

    Returns:
        An ``AgentExecutor`` wrapping a tool-calling agent backed by
        ``gpt-4o`` at temperature 0.
    """

    def answer_from_grid_code(query):
        return rag_chain.invoke(query)["answer"]

    def fetch_weather(zip_code):
        return weather_tool.run(zip_code)

    grid_code_tool = Tool(
        name="grid_code_query",
        description="Answer questions about the Grid Code and electrical regulations",
        func=answer_from_grid_code,
    )
    weather_lookup_tool = Tool(
        name="get_weather",
        description="Get weather forecast for a ZIP code. Input should be a 5-digit ZIP code.",
        func=fetch_weather,
    )
    tools = [grid_code_tool, weather_lookup_tool]

    # System instructions telling the model when to use each tool.
    system_instructions = """You are a helpful assistant that specializes in two areas:
1. Answering questions about electrical Grid Code regulations
2. Providing weather information for specific locations
For weather queries:
- Extract the ZIP code from the question
- Use the get_weather tool to fetch the forecast
For Grid Code questions:
- Use the grid_code_query tool to find relevant information
- If the information isn't in the Grid Code, clearly state that
- Provide specific references when possible
"""

    # Message order: system instructions, optional chat history, the user's
    # input, then the agent scratchpad.
    prompt_messages = [
        SystemMessagePromptTemplate.from_template(system_instructions),
        MessagesPlaceholder(variable_name="chat_history", optional=True),
        HumanMessagePromptTemplate.from_template("{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
    prompt = ChatPromptTemplate.from_messages(prompt_messages)

    llm = ChatOpenAI(model="gpt-4o", temperature=0)
    agent = create_tool_calling_agent(llm, tools, prompt)

    executor = AgentExecutor(
        agent=agent,
        tools=tools,
        verbose=True,
        handle_parsing_errors=True,
    )
    return executor
def display_weather(weather_data):
    """Render a weather result, or its error message, with Streamlit.

    Args:
        weather_data: Either a dict with an ``"error"`` message, or a dict
            with ``"type": "weather"`` plus ``location``, ``current`` and
            ``forecast`` entries. Any other dict renders nothing.
    """
    if "error" in weather_data:
        st.error(weather_data["error"])
        return

    if weather_data.get("type") != "weather":
        return

    # Location header
    st.header(f"Weather for {weather_data['location']}")

    # Current conditions, laid out in two columns
    now = weather_data["current"]
    st.subheader("Current Conditions")
    left, right = st.columns(2)

    with left:
        temperature_text = f"{now['temperature']}°{now['temperatureUnit']}"
        st.metric("Temperature", temperature_text)
        st.info(f"💨 Wind: {now['windSpeed']} {now['windDirection']}")

    with right:
        st.markdown(f"**🌤️ Conditions:** {now['shortForecast']}")
        st.markdown(f"**📝 Details:** {now['detailedForecast']}")

    # Extended forecast: one expander per upcoming period
    st.subheader("Extended Forecast")
    for upcoming in weather_data["forecast"]:
        with st.expander(f"📅 {upcoming['name']}"):
            temperature_line = (
                f"**🌡️ Temperature:** {upcoming['temperature']}°{upcoming['temperatureUnit']}"
            )
            st.markdown(temperature_line)
            wind_line = f"**💨 Wind:** {upcoming['windSpeed']} {upcoming['windDirection']}"
            st.markdown(wind_line)
            st.markdown(f"**🌤️ Forecast:** {upcoming['shortForecast']}")
            st.markdown(f"**📝 Details:** {upcoming['detailedForecast']}")
def _transcribe_audio(audio_bytes):
    """Transcribe recorded WAV bytes with the Whisper API and return the text.

    The bytes are written to a temporary ``.wav`` file that is handed to the
    API and removed afterwards, whether or not the request succeeds.
    """
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
        tmp_file.write(audio_bytes)
        tmp_path = tmp_file.name
    try:
        client = openai.OpenAI()
        with open(tmp_path, "rb") as audio_file:
            transcript = client.audio.transcriptions.create(
                model="whisper-1", file=audio_file
            )
        return transcript.text
    finally:
        # delete=False means nothing else removes the file; clean it up here
        # so recordings do not pile up in the temp directory.
        try:
            Path(tmp_path).unlink(missing_ok=True)
        except OSError as exc:
            logger.warning("Could not remove temporary audio file %s: %s", tmp_path, exc)


def main():
    """Render the page and handle one Streamlit script run.

    The run shows the logo, lazily builds the agent, draws the text/voice
    input row, transcribes any recorded audio, sends the submitted text or
    transcription to the agent, and renders the stored response (weather
    view or plain text).
    """
    image = Image.open("app/src/data/logo.png")
    st.image(image, use_container_width=True)

    # Initialize if not in session state
    if "app" not in st.session_state:
        rag_chain = initialize_rag()
        weather_tool = WeatherTool()
        st.session_state.app = create_agent_workflow(rag_chain, weather_tool)

    # Initialize session states if not present
    if "response" not in st.session_state:
        st.session_state.response = None

    # Initialize ALL state variables in a single place for clarity
    for key in [
        "transcription",
        "process_input",
        "audio_recorded",
        "audio_bytes",
        "clear_feedback",
    ]:
        if key not in st.session_state:
            st.session_state[key] = None

    # Create input area
    input_container = st.container()
    with input_container:
        st.write("Type your question:")
        # Create columns for input field, send button, and mic button
        col1, col2, col3 = st.columns([4, 1, 1])
        with col1:
            # Text input
            user_input = st.text_input(
                "",
                key="typed_input",
                label_visibility="collapsed",
            )
        with col2:
            # Send button for text input
            send_pressed = st.button("Send")
        with col3:
            # Audio recorder inline with input
            new_audio_bytes = audio_recorder(text="", icon_size="2x")

    # If new audio is recorded, immediately reset clear_feedback flag
    # This ensures transcription will be shown right away
    if new_audio_bytes:
        st.session_state.clear_feedback = False

    # Full width container for feedback (transcription, spinner, etc.)
    feedback_container = st.empty()  # Use st.empty() for easy clearing

    # Handle new audio recording - always process if there's new audio
    if new_audio_bytes:
        # Store audio for processing
        st.session_state.audio_bytes = new_audio_bytes
        st.session_state.audio_recorded = True
        with feedback_container.container():
            st.audio(new_audio_bytes, format="audio/wav")
            # Transcribe using Whisper API. Any failure is reported in the
            # page instead of aborting the run.
            try:
                with st.spinner("Transcribing..."):
                    transcribed_text = _transcribe_audio(new_audio_bytes)
                # Store the transcribed text
                if transcribed_text.strip():
                    st.session_state.transcription = transcribed_text
            except Exception as e:
                st.error(f"Error transcribing audio: {e}")
    else:
        # No new audio this run
        st.session_state.audio_recorded = False

    # Only show transcription info if not marked for clearing AND we have a transcription
    if not st.session_state.clear_feedback and st.session_state.transcription:
        with feedback_container.container():
            st.info(f"Transcribed: {st.session_state.transcription}")
            # Center the button
            col1, col2, col3 = st.columns([1.5, 2, 1.5])
            with col2:
                use_transcription = st.button(
                    "Use this transcription", key="use_transcript_btn"
                )
    else:
        use_transcription = False

    # If we were clearing feedback but now we don't have new audio,
    # reset the clear flag for next time
    if st.session_state.clear_feedback and not new_audio_bytes:
        st.session_state.clear_feedback = False

    # Determine if we need to process input
    process_text = send_pressed and user_input
    process_transcription = use_transcription and st.session_state.transcription

    # Clear the feedback container if Send is clicked with text input
    if send_pressed and user_input:
        feedback_container.empty()
        st.session_state.clear_feedback = True

    # Reset process_input at the start of each interaction
    st.session_state.process_input = None

    # Set the input to process based on what was submitted
    if process_text:
        st.session_state.process_input = user_input
        # Clear transcription when sending text input
        st.session_state.transcription = None
    elif process_transcription:
        st.session_state.process_input = st.session_state.transcription
        # Clear transcription after using it
        st.session_state.transcription = None
        # Set flag to clear feedback container on next run
        st.session_state.clear_feedback = True
        # Clear the container immediately
        feedback_container.empty()

    # Process input if available
    if st.session_state.process_input:
        # Drop weather data left behind by an earlier request that failed
        # before it was consumed, so it cannot be attributed to this one.
        if "weather_data" in st.session_state:
            del st.session_state.weather_data

        processing_container = st.container()
        with processing_container:
            with st.spinner("Processing your request..."):
                result = st.session_state.app.invoke(
                    {"input": st.session_state.process_input}
                )
            # Check if we have weather data in session state
            if "weather_data" in st.session_state:
                st.session_state.response = {
                    "type": "weather",
                    "data": st.session_state.weather_data,
                }
                del st.session_state.weather_data
            else:
                st.session_state.response = {
                    "type": "text",
                    "data": result["output"],
                }
        # Clear input after processing
        st.session_state.process_input = None

    # Display response in full width container
    if st.session_state.response:
        st.markdown("---")  # Add a separator
        # Use a container for the full-width response
        response_container = st.container()
        with response_container:
            if st.session_state.response["type"] == "weather":
                display_weather(st.session_state.response["data"])
            else:
                st.write(st.session_state.response["data"])


if __name__ == "__main__":
    main()