Jurabek commited on
Commit
3c72163
Β·
verified Β·
1 Parent(s): 4956423

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +64 -47
src/streamlit_app.py CHANGED
@@ -1,7 +1,10 @@
1
  import random
2
  import uuid
3
  import logging
 
4
  import streamlit as st
 
 
5
 
6
  # ==================================================
7
  # LOGGING
@@ -13,16 +16,23 @@ logging.basicConfig(
13
  logger = logging.getLogger("AI-Agent")
14
 
15
  # ==================================================
16
- # RAW DATA (600+ ROWS, IN-MEMORY)
 
 
 
 
 
17
  # ==================================================
18
  @st.cache_data
19
  def generate_data(rows: int = 600):
20
  random.seed(42)
 
21
  data = []
22
 
23
  for i in range(1, rows + 1):
24
  data.append({
25
  "order_id": i,
 
26
  "supplier": random.choice(
27
  ["Supplier A", "Supplier B", "Supplier C", "Supplier D"]
28
  ),
@@ -37,7 +47,7 @@ def generate_data(rows: int = 600):
37
  )
38
  })
39
 
40
- logger.info("Generated %d rows of raw data", len(data))
41
  return data
42
 
43
  DATA = generate_data()
@@ -46,10 +56,9 @@ DATA = generate_data()
46
  # TOOLS (FUNCTION CALLING)
47
  # ==================================================
48
  def tool_get_stats():
49
- logger.info("Tool β†’ get_stats")
50
-
51
  prices = [x["price"] for x in DATA]
52
  days = [x["delivery_days"] for x in DATA]
 
53
 
54
  return {
55
  "rows": len(DATA),
@@ -59,24 +68,16 @@ def tool_get_stats():
59
 
60
  def tool_query_product(product: str):
61
  logger.info("Tool β†’ query_product(%s)", product)
 
62
 
63
- return [
64
- {
65
- "order_id": x["order_id"],
66
- "supplier": x["supplier"],
67
- "quantity": x["quantity"],
68
- "price": x["price"],
69
- "delivery_days": x["delivery_days"],
70
- "status": x["status"],
71
- }
72
- for x in DATA
73
- if x["product"].lower() == product.lower()
74
- ][:5]
75
 
76
  def tool_create_support_ticket(text: str):
77
  ticket_id = str(uuid.uuid4())[:8]
78
  logger.info("Tool β†’ create_support_ticket (%s)", ticket_id)
79
-
80
  return {
81
  "ticket_id": ticket_id,
82
  "system": "GitHub Issues (mock)",
@@ -87,72 +88,87 @@ def tool_create_support_ticket(text: str):
87
  # ==================================================
88
  # SAFETY
89
  # ==================================================
90
- def is_dangerous(text: str) -> bool:
91
  blocked = ["delete", "drop", "truncate", "remove"]
92
  return any(b in text.lower() for b in blocked)
93
 
94
  # ==================================================
95
- # AGENT (DETERMINISTIC, TOOL-BASED)
96
  # ==================================================
97
  def agent(user_input: str):
98
  logger.info("User input: %s", user_input)
99
 
100
  if is_dangerous(user_input):
101
- logger.warning("Blocked dangerous operation")
102
  return "❌ Dangerous operations are not allowed."
103
 
104
  text = user_input.lower()
105
 
106
- # -------- FUNCTION CALL ROUTING --------
107
  if "stats" in text or "summary" in text:
108
- stats = tool_get_stats()
109
  return (
110
  f"πŸ“Š **Business Overview**\n"
111
- f"- Rows: {stats['rows']}\n"
112
- f"- Avg Price: ${stats['avg_price']}\n"
113
- f"- Avg Delivery Days: {stats['avg_delivery_days']}"
114
  )
115
 
 
 
 
 
 
 
 
 
 
 
116
  if text.startswith("show"):
117
  product = user_input.replace("show", "").strip()
118
  rows = tool_query_product(product)
119
-
120
  if not rows:
121
- return (
122
- "No records found for this product.\n\n"
123
- "Would you like me to create a support ticket?"
124
- )
125
  return rows
126
 
 
127
  if "support" in text or "ticket" in text:
128
- ticket = tool_create_support_ticket(user_input)
129
  return (
130
  f"🎫 **Support Ticket Created**\n"
131
- f"- ID: {ticket['ticket_id']}\n"
132
- f"- System: {ticket['system']}\n"
133
- f"- Status: {ticket['status']}"
134
  )
135
 
136
- # -------- FALLBACK (NO LLM) --------
137
- logger.info("Fallback response (LLM disabled)")
138
-
139
- return (
140
- "πŸ€– I can help you with:\n"
141
- "- database statistics (`stats`)\n"
142
- "- product queries (`show tomato`)\n"
143
- "- creating support tickets\n\n"
144
- "Please try one of the supported actions."
 
 
 
 
 
 
 
 
 
145
  )
146
 
 
 
147
  # ==================================================
148
  # STREAMLIT UI
149
  # ==================================================
150
  st.set_page_config(page_title="AI Procurement Agent", layout="wide")
151
- st.title("πŸ€– AI Procurement Agent β€” Single File")
152
 
153
- # -------- Sidebar: Business Info --------
154
  stats = tool_get_stats()
155
-
156
  st.sidebar.header("πŸ“Š Business Information")
157
  st.sidebar.metric("Rows", stats["rows"])
158
  st.sidebar.metric("Avg Price", f"${stats['avg_price']}")
@@ -161,11 +177,12 @@ st.sidebar.metric("Avg Delivery Days", stats["avg_delivery_days"])
161
  st.sidebar.markdown("### Sample queries")
162
  st.sidebar.code(
163
  "show tomato\n"
 
164
  "database stats\n"
165
  "create support ticket"
166
  )
167
 
168
- # -------- Chat --------
169
  if "messages" not in st.session_state:
170
  st.session_state.messages = []
171
 
 
1
  import random
2
  import uuid
3
  import logging
4
+ from datetime import datetime, timedelta
5
  import streamlit as st
6
+ from openai import OpenAI
7
+ import os
8
 
9
  # ==================================================
10
  # LOGGING
 
16
  logger = logging.getLogger("AI-Agent")
17
 
18
  # ==================================================
19
+ # OPENAI CLIENT
20
+ # ==================================================
21
+ client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
22
+
23
+ # ==================================================
24
+ # RAW DATA (600+ ROWS WITH DATE)
25
  # ==================================================
26
  @st.cache_data
27
  def generate_data(rows: int = 600):
28
  random.seed(42)
29
+ today = datetime.today().date()
30
  data = []
31
 
32
  for i in range(1, rows + 1):
33
  data.append({
34
  "order_id": i,
35
+ "order_date": today - timedelta(days=random.randint(0, 30)),
36
  "supplier": random.choice(
37
  ["Supplier A", "Supplier B", "Supplier C", "Supplier D"]
38
  ),
 
47
  )
48
  })
49
 
50
+ logger.info("Generated %d rows of data", len(data))
51
  return data
52
 
53
  DATA = generate_data()
 
56
  # TOOLS (FUNCTION CALLING)
57
  # ==================================================
58
  def tool_get_stats():
 
 
59
  prices = [x["price"] for x in DATA]
60
  days = [x["delivery_days"] for x in DATA]
61
+ logger.info("Tool β†’ get_stats")
62
 
63
  return {
64
  "rows": len(DATA),
 
68
 
69
  def tool_query_product(product: str):
70
  logger.info("Tool β†’ query_product(%s)", product)
71
+ return [x for x in DATA if x["product"].lower() == product.lower()][:5]
72
 
73
+ def tool_query_last_days(days: int):
74
+ logger.info("Tool β†’ query_last_days(%d)", days)
75
+ cutoff = datetime.today().date() - timedelta(days=days)
76
+ return [x for x in DATA if x["order_date"] >= cutoff][:10]
 
 
 
 
 
 
 
 
77
 
78
  def tool_create_support_ticket(text: str):
79
  ticket_id = str(uuid.uuid4())[:8]
80
  logger.info("Tool β†’ create_support_ticket (%s)", ticket_id)
 
81
  return {
82
  "ticket_id": ticket_id,
83
  "system": "GitHub Issues (mock)",
 
88
  # ==================================================
89
  # SAFETY
90
  # ==================================================
91
+ def is_dangerous(text: str):
92
  blocked = ["delete", "drop", "truncate", "remove"]
93
  return any(b in text.lower() for b in blocked)
94
 
95
  # ==================================================
96
+ # AGENT
97
  # ==================================================
98
  def agent(user_input: str):
99
  logger.info("User input: %s", user_input)
100
 
101
  if is_dangerous(user_input):
 
102
  return "❌ Dangerous operations are not allowed."
103
 
104
  text = user_input.lower()
105
 
106
+ # ---- STATS ----
107
  if "stats" in text or "summary" in text:
108
+ s = tool_get_stats()
109
  return (
110
  f"πŸ“Š **Business Overview**\n"
111
+ f"- Rows: {s['rows']}\n"
112
+ f"- Avg Price: ${s['avg_price']}\n"
113
+ f"- Avg Delivery Days: {s['avg_delivery_days']}"
114
  )
115
 
116
+ # ---- LAST N DAYS ----
117
+ if "last" in text and "day" in text:
118
+ try:
119
+ days = int([x for x in text.split() if x.isdigit()][0])
120
+ except:
121
+ days = 3
122
+ rows = tool_query_last_days(days)
123
+ return rows if rows else f"No data for last {days} days."
124
+
125
+ # ---- PRODUCT ----
126
  if text.startswith("show"):
127
  product = user_input.replace("show", "").strip()
128
  rows = tool_query_product(product)
 
129
  if not rows:
130
+ return "No data found. Would you like to create a support ticket?"
 
 
 
131
  return rows
132
 
133
+ # ---- SUPPORT ----
134
  if "support" in text or "ticket" in text:
135
+ t = tool_create_support_ticket(user_input)
136
  return (
137
  f"🎫 **Support Ticket Created**\n"
138
+ f"ID: {t['ticket_id']}\n"
139
+ f"System: {t['system']}"
 
140
  )
141
 
142
+ # ---- OPENAI FALLBACK ----
143
+ logger.info("LLM fallback via OpenAI")
144
+
145
+ response = client.chat.completions.create(
146
+ model="gpt-4o-mini",
147
+ messages=[
148
+ {
149
+ "role": "system",
150
+ "content": (
151
+ "You are an AI procurement assistant. "
152
+ "You do NOT have direct access to raw data. "
153
+ "Suggest support if needed."
154
+ )
155
+ },
156
+ {"role": "user", "content": user_input}
157
+ ],
158
+ max_tokens=150,
159
+ temperature=0.3
160
  )
161
 
162
+ return response.choices[0].message.content
163
+
164
  # ==================================================
165
  # STREAMLIT UI
166
  # ==================================================
167
  st.set_page_config(page_title="AI Procurement Agent", layout="wide")
168
+ st.title("πŸ€– AI Procurement Agent (OpenAI powered)")
169
 
170
+ # Sidebar
171
  stats = tool_get_stats()
 
172
  st.sidebar.header("πŸ“Š Business Information")
173
  st.sidebar.metric("Rows", stats["rows"])
174
  st.sidebar.metric("Avg Price", f"${stats['avg_price']}")
 
177
  st.sidebar.markdown("### Sample queries")
178
  st.sidebar.code(
179
  "show tomato\n"
180
+ "last 3 days\n"
181
  "database stats\n"
182
  "create support ticket"
183
  )
184
 
185
+ # Chat
186
  if "messages" not in st.session_state:
187
  st.session_state.messages = []
188