LightRT commited on
Commit
dd36c44
Β·
verified Β·
1 Parent(s): 0a85597

Rename src/streamlit_app.py to src/app.py

Browse files
Files changed (2) hide show
  1. src/app.py +114 -0
  2. src/streamlit_app.py +0 -40
src/app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import requests
3
+ import uuid
4
+ import time
5
+
6
+ # --- CONFIGURATION ---
7
+ API_URL = "https://lightrt-text2sql-backend.hf.space" # Your FastAPI server URL
8
+
9
+ st.set_page_config(page_title="Text@SQL Agent", page_icon="πŸ€–", layout="centered")
10
+
11
+ # --- SESSION STATE INITIALIZATION ---
12
+ # This ensures variables survive when Streamlit re-renders the page
13
+ if "thread_id" not in st.session_state:
14
+ st.session_state.thread_id = str(uuid.uuid4()) # Unique session ID for LangGraph memory
15
+ if "user_id" not in st.session_state:
16
+ st.session_state.user_id = "tenant_" + str(uuid.uuid4())[:8]
17
+ if "is_db_connected" not in st.session_state:
18
+ st.session_state.is_db_connected = False
19
+ if "connection_url" not in st.session_state:
20
+ st.session_state.connection_url = ""
21
+ if "chat_history" not in st.session_state:
22
+ st.session_state.chat_history = []
23
+
24
+ # --- SIDEBAR: DATABASE CONNECTION ---
25
+ with st.sidebar:
26
+ st.header("βš™οΈ Database Setup")
27
+
28
+ # If already connected, disable the input to enforce ONE database connection
29
+ db_input = st.text_input(
30
+ "Enter Database URL:",
31
+ disabled=st.session_state.is_db_connected
32
+ )
33
+
34
+ if not st.session_state.is_db_connected:
35
+ if st.button("Connect & Initialize", type="primary", use_container_width=True):
36
+ if not db_input:
37
+ st.error("Please enter a valid URL.")
38
+ else:
39
+ with st.spinner("Building embeddings and initializing agent..."):
40
+ try:
41
+ # 1. Hit your FastAPI upload endpoint
42
+ payload = {"connection_url": db_input, "user_id": st.session_state.user_id}
43
+ response = requests.post(f"{API_URL}/upload_url", json=payload)
44
+
45
+ if response.status_code == 200:
46
+ # 2. Lock the connection and unlock the chat
47
+ st.session_state.is_db_connected = True
48
+ st.session_state.connection_url = db_input
49
+
50
+ # Because your FastAPI upload uses BackgroundTasks, it returns instantly.
51
+ # We add a 2-second UI buffer here so the Qdrant embeddings have time to finish
52
+ # before the user fires off their first chat question.
53
+ time.sleep(15)
54
+
55
+ st.success("Database connected securely!")
56
+ st.rerun() # Refresh UI to unlock the chat window
57
+ else:
58
+ st.error(f"Failed to connect: {response.text}")
59
+ except requests.exceptions.ConnectionError:
60
+ st.error("🚨 Cannot connect to backend. Is FastAPI running?")
61
+ else:
62
+ st.success("βœ… Connected to Database")
63
+ st.caption(f"URL: {st.session_state.connection_url}")
64
+
65
+ # Add a reset button just in case they want to start completely over
66
+ if st.button("Disconnect & Reset", use_container_width=True):
67
+ st.session_state.clear()
68
+ st.rerun()
69
+
70
+ # --- MAIN CHAT INTERFACE ---
71
+ st.title("πŸ—£οΈ Text2SQL Agent")
72
+
73
+ # The Lock: Do not render the chat if DB is not connected
74
+ if not st.session_state.is_db_connected:
75
+ st.info("πŸ‘ˆ Please connect your database in the sidebar to begin analyzing data.")
76
+ else:
77
+ # 1. Display previous chat messages from session state
78
+ for msg in st.session_state.chat_history:
79
+ with st.chat_message(msg["role"]):
80
+ st.markdown(msg["content"])
81
+
82
+ # 2. The Chat Input box
83
+ if user_query := st.chat_input("Ask a question about your data..."):
84
+
85
+ # Immediately display the user's question in the UI
86
+ st.session_state.chat_history.append({"role": "user", "content": user_query})
87
+ with st.chat_message("user"):
88
+ st.markdown(user_query)
89
+
90
+ # 3. Call the LangGraph Backend
91
+ with st.chat_message("assistant"):
92
+ with st.spinner("Analyzing schema and generating SQL..."):
93
+ try:
94
+ payload = {
95
+ "message": user_query,
96
+ "thread_id": st.session_state.thread_id,
97
+ "user_id": st.session_state.user_id,
98
+ "connection_url": st.session_state.connection_url
99
+ }
100
+
101
+ response = requests.post(f"{API_URL}/chat", json=payload)
102
+
103
+ if response.status_code == 200:
104
+ # Extract the final_result from your FastAPI JSON response
105
+ answer = response.json().get("response", "No response found.")
106
+ st.markdown(answer)
107
+
108
+ # Save the assistant's answer to the UI history
109
+ st.session_state.chat_history.append({"role": "assistant", "content": answer})
110
+ else:
111
+ st.error(f"Agent Error: {response.text}")
112
+
113
+ except requests.exceptions.ConnectionError:
114
+ st.error("🚨 Connection dropped. Ensure FastAPI is running.")
src/streamlit_app.py DELETED
@@ -1,40 +0,0 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
- import streamlit as st
5
-
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))