Spaces:
Sleeping
Sleeping
| import logging | |
| import os | |
| import sys | |
| import tempfile | |
| from pathlib import Path | |
| from typing import Annotated, TypedDict | |
| import openai | |
| import requests | |
| import streamlit as st | |
| from audio_recorder_streamlit import audio_recorder | |
| from langchain.agents import AgentExecutor, create_tool_calling_agent | |
| from langchain_core.prompts import ( | |
| ChatPromptTemplate, | |
| HumanMessagePromptTemplate, | |
| MessagesPlaceholder, | |
| SystemMessagePromptTemplate, | |
| ) | |
| from langchain_core.tools import Tool | |
| from langchain_openai import ChatOpenAI | |
| from langgraph.graph.message import add_messages | |
| from PIL import Image | |
| sys.path.append(str(Path(__file__).parent)) | |
| from dotenv import load_dotenv | |
| # Now import from app.src | |
| from app.src.embedding.model import EmbeddingModel | |
| from app.src.rag.chain import RAGChain | |
| from app.src.rag.document_loader import GridCodeLoader | |
| from app.src.rag.vectorstore import VectorStore | |
# Read settings from the .env file that sits next to this module (independent
# of the current working directory), then create the module logger.
load_dotenv(dotenv_path=Path(__file__).parent.joinpath(".env"))
logger = logging.getLogger(name=__name__)
def get_secrets():
    """Collect the app's configuration values from environment variables.

    Streamlit secrets are deliberately not consulted. Keys listed with a
    ``None`` fallback map to ``None`` when unset; the LangChain project and
    tracing flags fall back to ``"GridGuide"`` and ``"true"``.

    Returns:
        dict mapping each setting name to its value (or ``None``).
    """
    fallbacks = {
        "OPENAI_API_KEY": None,
        "LANGCHAIN_API_KEY": None,
        "LANGCHAIN_PROJECT": "GridGuide",
        "LANGCHAIN_TRACING_V2": "true",
    }
    return {name: os.getenv(name, default) for name, default in fallbacks.items()}
# Push every configured setting (including the LangChain defaults) into the
# process environment so libraries that read os.environ directly see them.
# Settings without a value are skipped.
secrets = get_secrets()
os.environ.update({key: value for key, value in secrets.items() if value})

# Nothing works without an OpenAI key: report it on the page and halt the run.
if not os.environ.get("OPENAI_API_KEY"):
    st.error("OpenAI API key not found. Please check your .env file.")
    st.stop()
class WeatherTool:
    """Look up a short National Weather Service forecast for a US ZIP code.

    The ZIP code is geocoded with api.zippopotam.us, the coordinates are then
    resolved to a forecast URL via api.weather.gov ``/points``, and that
    forecast is fetched. Network and response-format failures are reported
    through the return value (``None`` or ``{"error": ...}``) rather than
    raised, so the agent can relay the problem to the user.
    """

    def __init__(self, timeout=10):
        """Configure the NWS base URL, request headers and HTTP timeout.

        Args:
            timeout: Seconds to wait on each HTTP request. ``requests`` applies
                no timeout by default, so without one a stalled API would hang
                the Streamlit script run indefinitely.
        """
        self.base_url = "https://api.weather.gov"
        # Sent with every api.weather.gov request.
        self.headers = {
            "User-Agent": "(Grid Code Assistant, contact@example.com)",
            "Accept": "application/json",
        }
        self.timeout = timeout

    @staticmethod
    def _normalize_zip(zipcode):
        """Return the 5-digit ZIP contained in *zipcode*, or None if malformed.

        Surrounding whitespace and the ZIP+4 form ("12345-6789") are accepted.
        Only ASCII digits pass, which also keeps the value safe to splice into
        a URL path.
        """
        text = str(zipcode).strip()
        base, sep, plus4 = text.partition("-")
        if sep and not (len(plus4) == 4 and plus4.isascii() and plus4.isdigit()):
            return None
        if len(base) == 5 and base.isascii() and base.isdigit():
            return base
        return None

    def _fetch_json(self, url, headers=None):
        """GET *url* and return the decoded JSON body, or None on any failure.

        Failures are a transport error, a non-200 status, or a body that is
        not valid JSON.
        """
        try:
            response = requests.get(url, headers=headers, timeout=self.timeout)
        except requests.RequestException:
            logger.warning("Request to %s failed", url, exc_info=True)
            return None
        if response.status_code != 200:
            return None
        try:
            return response.json()
        except ValueError:  # JSON decode errors subclass ValueError
            logger.warning("Non-JSON response from %s", url)
            return None

    def get_coordinates_from_zip(self, zipcode):
        """Geocode a US ZIP code.

        Returns:
            dict with "lat", "lon", "place" and "state" for the first place
            listed for the ZIP, or None if the ZIP is malformed, unknown, or
            the lookup fails.
        """
        zip5 = self._normalize_zip(zipcode)
        if zip5 is None:
            return None
        data = self._fetch_json(f"https://api.zippopotam.us/us/{zip5}")
        try:
            place = data["places"][0]
            return {
                "lat": place["latitude"],
                "lon": place["longitude"],
                "place": place["place name"],
                "state": place["state"],
            }
        except (TypeError, KeyError, IndexError):
            # data is None (request failed) or has an unexpected shape.
            return None

    def run(self, zipcode):
        """Fetch the current period and the next three forecast periods.

        On success the result is also stored in
        ``st.session_state.weather_data`` so ``main`` can render it as a
        weather card.

        Returns:
            On success ``{"type": "weather", "location": "Place, ST",
            "current": <first period>, "forecast": <up to 3 next periods>}``;
            otherwise ``{"error": <message>}``.
        """
        coords = self.get_coordinates_from_zip(zipcode)
        if not coords:
            return {"error": "Invalid ZIP code or unable to get coordinates."}

        point_url = f"{self.base_url}/points/{coords['lat']},{coords['lon']}"
        grid_data = self._fetch_json(point_url, headers=self.headers)
        try:
            forecast_url = grid_data["properties"]["forecast"]
        except (TypeError, KeyError):
            return {"error": "Unable to fetch weather data."}

        forecast_json = self._fetch_json(forecast_url, headers=self.headers)
        try:
            periods = forecast_json["properties"]["periods"]
            current = periods[0]  # an empty list previously raised IndexError
        except (TypeError, KeyError, IndexError):
            return {"error": "Unable to fetch forecast data."}

        weather_data = {
            "type": "weather",
            "location": f"{coords['place']}, {coords['state']}",
            "current": current,
            "forecast": periods[1:4],
        }
        # Hand the structured data to the UI, which reads it after the agent returns.
        st.session_state.weather_data = weather_data
        return weather_data
def initialize_rag():
    """Build, or reuse, the Grid Code RAG chain for this Streamlit session.

    ``Grid_Code.pdf`` is looked up in the current working directory, at
    ``/app/Grid_Code.pdf`` (Docker image), and next to this file, in that
    order. The first match is loaded with ``GridCodeLoader(..., pages=25)``
    (presumably a page limit; see GridCodeLoader), split, embedded into a
    vector store, and wrapped in a ``RAGChain``. That chain is cached in
    ``st.session_state.rag_chain``.

    Returns:
        The cached or newly built ``RAGChain``.

    Raises:
        FileNotFoundError: If the PDF is not found at any candidate path.
    """
    if "rag_chain" in st.session_state:
        logger.info("Using cached RAG chain from session state")
        return st.session_state.rag_chain

    candidates = [
        Path("Grid_Code.pdf"),  # current working directory (local and Docker)
        Path("/app/Grid_Code.pdf"),  # Docker container path
        Path(__file__).parent / "Grid_Code.pdf",  # alongside this module
    ]
    data_path = None
    for candidate in candidates:
        logger.info("Checking path: %s", candidate)
        # is_file() rather than exists(): a directory of this name is unusable.
        if candidate.is_file():
            data_path = str(candidate)
            logger.info("Found PDF at: %s", data_path)
            break
    if data_path is None:
        searched = ", ".join(str(candidate) for candidate in candidates)
        raise FileNotFoundError(
            f"PDF not found in any of these locations: {searched}"
        )

    with st.spinner("Loading Grid Code documents..."):
        loader = GridCodeLoader(data_path, pages=25)
        documents = loader.load_and_split()
        logger.info("Loaded %d document chunks", len(documents))

    with st.spinner("Creating vector store..."):
        embedding_model = EmbeddingModel()
        store_builder = VectorStore(embedding_model)
        vectorstore = store_builder.create_vectorstore(documents)
        logger.info("Vector store created successfully")

    # Cache the RAG chain in session state so reruns skip the rebuild.
    rag_chain = RAGChain(vectorstore)
    st.session_state.rag_chain = rag_chain
    return rag_chain
class RAGTool:
    """Adapter that answers a question with a RAG chain and returns only the answer text."""

    def __init__(self, rag_chain):
        # Expected to expose ``invoke(question)`` returning a mapping that
        # contains an ``"answer"`` entry.
        self.rag_chain = rag_chain

    def run(self, question: str) -> str:
        """Return the chain's ``"answer"`` for *question* (Grid Code Q&A)."""
        chain_output = self.rag_chain.invoke(question)
        return chain_output["answer"]
class AgentState(TypedDict):
    """State definition for the agent.

    Typed-dict schema holding a single ``messages`` list.

    NOTE(review): This class is not referenced anywhere else in this file.
    The agent built by ``create_agent_workflow`` is an ``AgentExecutor``.
    Confirm whether this schema is still needed.
    """

    # Chat message list. The ``add_messages`` annotation metadata is presumably
    # langgraph's reducer for merging message updates into this list; see the
    # langgraph docs.
    messages: Annotated[list, add_messages]
def create_agent_workflow(rag_chain, weather_tool):
    """Assemble a tool-calling agent over the Grid Code RAG chain and weather lookup.

    Args:
        rag_chain: Chain whose ``invoke(question)`` returns a mapping with an
            ``"answer"`` entry.
        weather_tool: Object whose ``run(zipcode)`` returns forecast data.

    Returns:
        An ``AgentExecutor`` (verbose, tolerant of parsing errors) driving a
        ``gpt-4o`` agent at temperature 0. The agent can call the
        ``grid_code_query`` and ``get_weather`` tools.
    """
    grid_code_tool = Tool(
        name="grid_code_query",
        description="Answer questions about the Grid Code and electrical regulations",
        func=RAGTool(rag_chain).run,
    )
    weather_lookup_tool = Tool(
        name="get_weather",
        description="Get weather forecast for a ZIP code. Input should be a 5-digit ZIP code.",
        func=weather_tool.run,
    )
    tools = [grid_code_tool, weather_lookup_tool]

    # System instructions telling the model when to use each tool.
    system_text = (
        "You are a helpful assistant that specializes in two areas:\n"
        "1. Answering questions about electrical Grid Code regulations\n"
        "2. Providing weather information for specific locations\n"
        "For weather queries:\n"
        "- Extract the ZIP code from the question\n"
        "- Use the get_weather tool to fetch the forecast\n"
        "For Grid Code questions:\n"
        "- Use the grid_code_query tool to find relevant information\n"
        "- If the information isn't in the Grid Code, clearly state that\n"
        "- Provide specific references when possible\n"
    )
    message_templates = [
        SystemMessagePromptTemplate.from_template(system_text),
        MessagesPlaceholder(variable_name="chat_history", optional=True),
        HumanMessagePromptTemplate.from_template("{input}"),
        # Tool calls and their results are threaded through here.
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
    prompt = ChatPromptTemplate.from_messages(message_templates)

    llm = ChatOpenAI(model="gpt-4o", temperature=0)
    agent = create_tool_calling_agent(llm, tools, prompt)
    return AgentExecutor(
        agent=agent,
        tools=tools,
        verbose=True,
        handle_parsing_errors=True,
    )
def display_weather(weather_data):
    """Render a WeatherTool result with Streamlit widgets.

    An ``{"error": ...}`` payload is shown with ``st.error``. A payload tagged
    ``"type": "weather"`` gets:

    - a location header;
    - a two-column current-conditions panel (temperature metric and wind on
      the left, summary and details on the right);
    - one expander per extended-forecast period.

    Any other payload renders nothing.
    """
    if "error" in weather_data:
        st.error(weather_data["error"])
        return
    if weather_data.get("type") != "weather":
        return

    st.header(f"Weather for {weather_data['location']}")
    now = weather_data["current"]

    st.subheader("Current Conditions")
    left, right = st.columns(2)
    with left:
        reading = f"{now['temperature']}°{now['temperatureUnit']}"
        st.metric("Temperature", reading)
        st.info(f"💨 Wind: {now['windSpeed']} {now['windDirection']}")
    with right:
        st.markdown(f"**🌤️ Conditions:** {now['shortForecast']}")
        st.markdown(f"**📝 Details:** {now['detailedForecast']}")

    st.subheader("Extended Forecast")
    for slot in weather_data["forecast"]:
        with st.expander(f"📅 {slot['name']}"):
            st.markdown(
                f"**🌡️ Temperature:** {slot['temperature']}°{slot['temperatureUnit']}"
            )
            st.markdown(f"**💨 Wind:** {slot['windSpeed']} {slot['windDirection']}")
            st.markdown(f"**🌤️ Forecast:** {slot['shortForecast']}")
            st.markdown(f"**📝 Details:** {slot['detailedForecast']}")
def _init_session_state():
    """Build the agent once per session and default every UI state key to None.

    If the Grid Code PDF cannot be found, the error is shown on the page and
    the run is halted with ``st.stop()`` instead of surfacing a traceback.
    """
    if "app" not in st.session_state:
        try:
            rag_chain = initialize_rag()
        except FileNotFoundError as exc:
            logger.error("Grid Code PDF missing: %s", exc)
            st.error(str(exc))
            st.stop()
        weather_tool = WeatherTool()
        st.session_state.app = create_agent_workflow(rag_chain, weather_tool)

    for key in (
        "response",
        "transcription",
        "process_input",
        "audio_recorded",
        "audio_bytes",
        "clear_feedback",
    ):
        if key not in st.session_state:
            st.session_state[key] = None


def _transcribe_audio(audio_bytes):
    """Transcribe WAV *audio_bytes* with OpenAI Whisper and return the text.

    The audio goes through a named temporary file, which is always deleted
    afterwards. Exceptions from the OpenAI client propagate to the caller.
    """
    # delete=False so the file can be closed and reopened by name below; it is
    # therefore removed explicitly in the finally clause.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
        tmp_file.write(audio_bytes)
        tmp_path = tmp_file.name
    try:
        client = openai.OpenAI()
        with open(tmp_path, "rb") as audio_file:
            transcript = client.audio.transcriptions.create(
                model="whisper-1", file=audio_file
            )
        return transcript.text
    finally:
        try:
            os.unlink(tmp_path)
        except OSError:
            logger.warning("Could not delete temporary audio file %s", tmp_path)


def _run_agent(query):
    """Invoke the agent on *query* and store the result in ``st.session_state.response``.

    If the weather tool ran during this invocation, its structured data (left
    in ``st.session_state.weather_data``) becomes a "weather" response.
    Otherwise the agent's text output is used.
    """
    # Drop weather data left behind by an earlier (possibly failed) run, so it
    # cannot be mistaken for this query's result.
    if "weather_data" in st.session_state:
        del st.session_state.weather_data
    result = st.session_state.app.invoke({"input": query})
    if "weather_data" in st.session_state:
        st.session_state.response = {
            "type": "weather",
            "data": st.session_state.weather_data,
        }
        del st.session_state.weather_data
    else:
        st.session_state.response = {
            "type": "text",
            "data": result["output"],
        }


def main():
    """Render the GridGuide page for one Streamlit script run.

    Flow:

    1. Show the logo.
    2. Ensure the agent and the UI state exist.
    3. Collect either a typed question (Send) or a voice recording. The
       recording is transcribed with Whisper and confirmed with the "Use this
       transcription" button.
    4. Run the agent on the chosen input.
    5. Render the most recent response as a weather card or as plain text.
    """
    image = Image.open("app/src/data/logo.png")
    st.image(image, use_container_width=True)

    _init_session_state()

    # Input row: text field, Send button and microphone recorder.
    with st.container():
        st.write("Type your question:")
        col1, col2, col3 = st.columns([4, 1, 1])
        with col1:
            # The label is non-empty but hidden. Streamlit warns about empty
            # labels for accessibility reasons.
            user_input = st.text_input(
                "Your question",
                key="typed_input",
                label_visibility="collapsed",
            )
        with col2:
            send_pressed = st.button("Send")
        with col3:
            new_audio_bytes = audio_recorder(text="", icon_size="2x")

    # NOTE(review): Custom components such as audio_recorder appear to return
    # their last value again on later reruns. Only a clip that differs from
    # the last one handled counts as new, so each recording is transcribed
    # exactly once.
    is_new_audio = (
        bool(new_audio_bytes) and new_audio_bytes != st.session_state.audio_bytes
    )
    if is_new_audio:
        # Make sure the new transcription is shown right away.
        st.session_state.clear_feedback = False

    # Full-width feedback slot (audio, transcription, errors). st.empty() allows easy clearing.
    feedback_container = st.empty()

    if is_new_audio:
        st.session_state.audio_bytes = new_audio_bytes
        st.session_state.audio_recorded = True
        with feedback_container.container():
            st.audio(new_audio_bytes, format="audio/wav")
            try:
                with st.spinner("Transcribing..."):
                    transcribed_text = _transcribe_audio(new_audio_bytes)
                if transcribed_text.strip():
                    st.session_state.transcription = transcribed_text
            except Exception as e:  # UI boundary: report instead of crashing the page
                logger.exception("Audio transcription failed")
                st.error(f"Error transcribing audio: {e}")
    else:
        st.session_state.audio_recorded = False

    # Offer the pending transcription unless it was just consumed (clear_feedback).
    if not st.session_state.clear_feedback and st.session_state.transcription:
        with feedback_container.container():
            st.info(f"Transcribed: {st.session_state.transcription}")
            col1, col2, col3 = st.columns([1.5, 2, 1.5])  # centre the button
            with col2:
                use_transcription = st.button(
                    "Use this transcription", key="use_transcript_btn"
                )
    else:
        use_transcription = False

    # The clear flag suppresses feedback for one run only.
    if st.session_state.clear_feedback and not is_new_audio:
        st.session_state.clear_feedback = False

    process_text = send_pressed and user_input
    process_transcription = use_transcription and st.session_state.transcription

    if process_text:
        feedback_container.empty()
        st.session_state.clear_feedback = True

    # Pick the input for this run: typed text wins over a pending transcription.
    st.session_state.process_input = None
    if process_text:
        st.session_state.process_input = user_input
        st.session_state.transcription = None
    elif process_transcription:
        st.session_state.process_input = st.session_state.transcription
        st.session_state.transcription = None
        st.session_state.clear_feedback = True
        feedback_container.empty()

    if st.session_state.process_input:
        with st.container():
            try:
                with st.spinner("Processing your request..."):
                    _run_agent(st.session_state.process_input)
            except Exception as exc:  # UI boundary: report instead of crashing the page
                logger.exception("Agent invocation failed")
                # Don't leave the previous answer on screen under a new error.
                st.session_state.response = None
                st.error(f"Error processing your request: {exc}")
        st.session_state.process_input = None

    if st.session_state.response:
        st.markdown("---")
        with st.container():
            response = st.session_state.response
            if response["type"] == "weather":
                display_weather(response["data"])
            else:
                st.write(response["data"])


if __name__ == "__main__":
    main()