Update app.py
Browse files
app.py
CHANGED
|
@@ -4,9 +4,11 @@ import json
|
|
| 4 |
import requests
|
| 5 |
import traceback
|
| 6 |
|
|
|
|
| 7 |
st.set_page_config(page_title="JSON-Backed AI Chat Agent", layout="wide")
|
| 8 |
st.title("JSON-Backed AI Chat Agent")
|
| 9 |
|
|
|
|
| 10 |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
| 11 |
if not OPENAI_API_KEY:
|
| 12 |
st.error("❌ OPENAI_API_KEY not set in Settings → Secrets.")
|
|
@@ -51,31 +53,33 @@ if uploaded_files:
|
|
| 51 |
system_message = {
|
| 52 |
"role": "system",
|
| 53 |
"content": (
|
| 54 |
-
"You are an AI data analyst for the
|
| 55 |
-
"
|
| 56 |
-
"
|
| 57 |
-
"
|
| 58 |
-
"
|
| 59 |
-
"If the
|
| 60 |
-
)
|
| 61 |
}
|
| 62 |
st.session_state.messages = [system_message]
|
| 63 |
else:
|
| 64 |
st.session_state.json_data.clear()
|
| 65 |
|
| 66 |
def search_json(file_name, key, value):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
try:
|
| 68 |
data = st.session_state.json_data[file_name]
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
return results[:10]
|
| 72 |
-
elif isinstance(data, dict):
|
| 73 |
-
if key in data and str(data[key]) == str(value):
|
| 74 |
-
return [{key: value}]
|
| 75 |
-
else:
|
| 76 |
-
return []
|
| 77 |
-
else:
|
| 78 |
-
return []
|
| 79 |
except Exception as e:
|
| 80 |
return {"error": str(e)}
|
| 81 |
|
|
@@ -103,6 +107,22 @@ def count_key_occurrences(file_name, key):
|
|
| 103 |
except Exception as e:
|
| 104 |
return {"error": str(e)}
|
| 105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
function_schema = [
|
| 107 |
{
|
| 108 |
"name": "search_json",
|
|
@@ -140,6 +160,18 @@ function_schema = [
|
|
| 140 |
"required": ["file_name", "key"],
|
| 141 |
},
|
| 142 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
]
|
| 144 |
|
| 145 |
st.markdown("### Conversation")
|
|
@@ -192,13 +224,13 @@ def send_message():
|
|
| 192 |
args = json.loads(args_json)
|
| 193 |
|
| 194 |
if func_name == "search_json":
|
| 195 |
-
result = search_json(
|
| 196 |
-
args.get("file_name"), args.get("key"), args.get("value")
|
| 197 |
-
)
|
| 198 |
elif func_name == "list_keys":
|
| 199 |
result = list_keys(args.get("file_name"))
|
| 200 |
elif func_name == "count_key_occurrences":
|
| 201 |
result = count_key_occurrences(args.get("file_name"), args.get("key"))
|
|
|
|
|
|
|
| 202 |
else:
|
| 203 |
result = {"error": f"Unknown function: {func_name}"}
|
| 204 |
|
|
|
|
| 4 |
import requests
|
| 5 |
import traceback
|
| 6 |
|
| 7 |
+
# --- Page config
|
| 8 |
st.set_page_config(page_title="JSON-Backed AI Chat Agent", layout="wide")
|
| 9 |
st.title("JSON-Backed AI Chat Agent")
|
| 10 |
|
| 11 |
+
# --- Load API key
|
| 12 |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
| 13 |
if not OPENAI_API_KEY:
|
| 14 |
st.error("❌ OPENAI_API_KEY not set in Settings → Secrets.")
|
|
|
|
| 53 |
system_message = {
|
| 54 |
"role": "system",
|
| 55 |
"content": (
|
| 56 |
+
"You are an AI data analyst for the uploaded JSON files. "
|
| 57 |
+
"Each file may have a different structure and set of keys. "
|
| 58 |
+
"When the user asks a question about all data (like 'How many females are there?'), "
|
| 59 |
+
"use the 'search_all_jsons' function to search through every uploaded file recursively. "
|
| 60 |
+
"Return results with the count and details, grouped by file if needed. "
|
| 61 |
+
"If the question is specific to one file, use the appropriate function for that file."
|
| 62 |
+
)
|
| 63 |
}
|
| 64 |
st.session_state.messages = [system_message]
|
| 65 |
else:
|
| 66 |
st.session_state.json_data.clear()
|
| 67 |
|
| 68 |
def search_json(file_name, key, value):
|
| 69 |
+
def recursive_search(obj, key, value, found):
|
| 70 |
+
if isinstance(obj, dict):
|
| 71 |
+
if key in obj and str(obj[key]).lower() == str(value).lower():
|
| 72 |
+
found.append(obj)
|
| 73 |
+
for v in obj.values():
|
| 74 |
+
recursive_search(v, key, value, found)
|
| 75 |
+
elif isinstance(obj, list):
|
| 76 |
+
for item in obj:
|
| 77 |
+
recursive_search(item, key, value, found)
|
| 78 |
+
return found
|
| 79 |
try:
|
| 80 |
data = st.session_state.json_data[file_name]
|
| 81 |
+
results = recursive_search(data, key, value, [])
|
| 82 |
+
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
except Exception as e:
|
| 84 |
return {"error": str(e)}
|
| 85 |
|
|
|
|
| 107 |
except Exception as e:
|
| 108 |
return {"error": str(e)}
|
| 109 |
|
| 110 |
+
def search_all_jsons(key, value):
|
| 111 |
+
"""Search all JSON files for dicts with key==value, recursively."""
|
| 112 |
+
found = []
|
| 113 |
+
for file_name, data in st.session_state.json_data.items():
|
| 114 |
+
def recursive_search(obj):
|
| 115 |
+
if isinstance(obj, dict):
|
| 116 |
+
if key in obj and str(obj[key]).lower() == str(value).lower():
|
| 117 |
+
found.append({**obj, "__file__": file_name})
|
| 118 |
+
for v in obj.values():
|
| 119 |
+
recursive_search(v)
|
| 120 |
+
elif isinstance(obj, list):
|
| 121 |
+
for item in obj:
|
| 122 |
+
recursive_search(item)
|
| 123 |
+
recursive_search(data)
|
| 124 |
+
return found
|
| 125 |
+
|
| 126 |
function_schema = [
|
| 127 |
{
|
| 128 |
"name": "search_json",
|
|
|
|
| 160 |
"required": ["file_name", "key"],
|
| 161 |
},
|
| 162 |
},
|
| 163 |
+
{
|
| 164 |
+
"name": "search_all_jsons",
|
| 165 |
+
"description": "Search all uploaded JSON files recursively for dicts where a key matches a value.",
|
| 166 |
+
"parameters": {
|
| 167 |
+
"type": "object",
|
| 168 |
+
"properties": {
|
| 169 |
+
"key": {"type": "string", "description": "The key to search for (e.g. 'gender')"},
|
| 170 |
+
"value": {"type": "string", "description": "The value to match (e.g. 'female')"}
|
| 171 |
+
},
|
| 172 |
+
"required": ["key", "value"]
|
| 173 |
+
}
|
| 174 |
+
}
|
| 175 |
]
|
| 176 |
|
| 177 |
st.markdown("### Conversation")
|
|
|
|
| 224 |
args = json.loads(args_json)
|
| 225 |
|
| 226 |
if func_name == "search_json":
|
| 227 |
+
result = search_json(args.get("file_name"), args.get("key"), args.get("value"))
|
|
|
|
|
|
|
| 228 |
elif func_name == "list_keys":
|
| 229 |
result = list_keys(args.get("file_name"))
|
| 230 |
elif func_name == "count_key_occurrences":
|
| 231 |
result = count_key_occurrences(args.get("file_name"), args.get("key"))
|
| 232 |
+
elif func_name == "search_all_jsons":
|
| 233 |
+
result = search_all_jsons(args.get("key"), args.get("value"))
|
| 234 |
else:
|
| 235 |
result = {"error": f"Unknown function: {func_name}"}
|
| 236 |
|