Jurabek commited on
Commit
ac5adde
Β·
verified Β·
1 Parent(s): 432eed9

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +112 -64
src/streamlit_app.py CHANGED
@@ -1,51 +1,80 @@
1
- import sqlite3
2
  import uuid
3
  import logging
4
- import pandas as pd
5
  import streamlit as st
6
  from huggingface_hub import InferenceClient
7
 
8
- # ======================
9
  # LOGGING
10
- # ======================
11
  logging.basicConfig(level=logging.INFO)
12
  logger = logging.getLogger("AI-Agent")
13
 
14
- # ======================
15
- # DB CONNECTION
16
- # ======================
17
- @st.cache_resource
18
- def get_db():
19
- return sqlite3.connect("procurement.db", check_same_thread=False)
20
-
21
- conn = get_db()
22
-
23
- # ======================
24
- # TOOLS (FUNCTION CALLS)
25
- # ======================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def tool_get_stats():
27
  logger.info("Tool called: get_stats")
28
- q = """
29
- SELECT COUNT(*) rows,
30
- ROUND(AVG(price),2) avg_price,
31
- ROUND(AVG(delivery_days),2) avg_delivery_days
32
- FROM orders
33
- """
34
- return pd.read_sql(q, conn).iloc[0].to_dict()
35
-
36
- def tool_query_product(product):
 
 
37
  logger.info("Tool called: query_product(%s)", product)
38
- q = """
39
- SELECT order_id, supplier, quantity, price, delivery_days, status
40
- FROM orders
41
- WHERE LOWER(product)=LOWER(?)
42
- LIMIT 5
43
- """
44
- return pd.read_sql(q, conn, params=(product,))
45
-
46
- def tool_create_ticket(text):
 
 
 
 
 
 
47
  ticket_id = str(uuid.uuid4())[:8]
48
  logger.info("Support ticket created: %s", ticket_id)
 
49
  return {
50
  "ticket_id": ticket_id,
51
  "system": "GitHub Issues (mock)",
@@ -53,67 +82,80 @@ def tool_create_ticket(text):
53
  "description": text,
54
  }
55
 
56
- # ======================
57
  # SAFETY
58
- # ======================
59
- def is_dangerous(text):
60
  blocked = ["delete", "drop", "truncate", "remove"]
61
  return any(b in text.lower() for b in blocked)
62
 
63
- # ======================
64
  # AGENT (INTENT β†’ FUNCTION CALL)
65
- # ======================
66
- def agent(user_input, client):
 
 
67
  if is_dangerous(user_input):
68
- logger.warning("Blocked dangerous request")
69
  return "❌ Dangerous operation is not allowed."
70
 
71
  text = user_input.lower()
72
 
 
73
  if "stats" in text or "summary" in text:
74
- stats = tool_get_stats()
75
  return (
76
  f"πŸ“Š **Business Overview**\n"
77
- f"- Rows: {stats['rows']}\n"
78
- f"- Avg Price: ${stats['avg_price']}\n"
79
- f"- Avg Delivery Days: {stats['avg_delivery_days']}"
80
  )
81
 
82
  if text.startswith("show"):
83
  product = user_input.replace("show", "").strip()
84
- df = tool_query_product(product)
85
- if df.empty:
86
- return "No data found. Would you like me to create a support ticket?"
87
- return df
 
 
 
 
88
 
89
  if "support" in text or "ticket" in text:
90
- ticket = tool_create_ticket(user_input)
91
  return (
92
  f"🎫 **Support Ticket Created**\n"
93
- f"ID: {ticket['ticket_id']}\n"
94
- f"System: {ticket['system']}"
 
95
  )
96
 
97
- # LLM fallback (ONLY TEXT + SMALL CONTEXT)
98
  prompt = f"""
99
  You are an AI procurement assistant.
100
- You do NOT have direct DB access.
101
- If question seems unclear, suggest contacting support.
102
 
103
  User: {user_input}
104
  Assistant:
105
  """
106
  logger.info("LLM called")
107
- return client.text_generation(prompt, max_new_tokens=150, temperature=0.3)
 
 
 
 
108
 
109
- # ======================
110
  # STREAMLIT UI
111
- # ======================
112
  st.set_page_config(page_title="AI Procurement Agent", layout="wide")
113
- st.title("πŸ€– AI Procurement Agent")
114
 
115
- # Sidebar business info
116
  stats = tool_get_stats()
 
117
  st.sidebar.metric("Rows", stats["rows"])
118
  st.sidebar.metric("Avg Price", f"${stats['avg_price']}")
119
  st.sidebar.metric("Avg Delivery Days", stats["avg_delivery_days"])
@@ -125,9 +167,15 @@ database stats
125
  create support ticket
126
  """)
127
 
128
- # HF Token
129
- hf_token = st.sidebar.text_input("HF Token", type="password")
 
 
 
 
 
130
  if not hf_token:
 
131
  st.stop()
132
 
133
  client = InferenceClient(
@@ -135,11 +183,11 @@ client = InferenceClient(
135
  token=hf_token
136
  )
137
 
138
- # Chat
139
  if "messages" not in st.session_state:
140
  st.session_state.messages = []
141
 
142
- user_input = st.chat_input("Ask the agent...")
143
 
144
  if user_input:
145
  st.session_state.messages.append(("user", user_input))
 
1
+ import random
2
  import uuid
3
  import logging
 
4
  import streamlit as st
5
  from huggingface_hub import InferenceClient
6
 
7
+ # ==================================================
8
  # LOGGING
9
+ # ==================================================
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger("AI-Agent")
12
 
13
+ # ==================================================
14
+ # RAW DATA (IN-MEMORY, 600+ ROWS)
15
+ # ==================================================
16
+ @st.cache_data
17
+ def generate_data(rows=600):
18
+ random.seed(42)
19
+ data = []
20
+
21
+ for i in range(1, rows + 1):
22
+ data.append({
23
+ "order_id": i,
24
+ "supplier": random.choice(
25
+ ["Supplier A", "Supplier B", "Supplier C", "Supplier D"]
26
+ ),
27
+ "product": random.choice(
28
+ ["Tomato", "Cheese", "Flour", "Oil", "Meat"]
29
+ ),
30
+ "quantity": random.randint(1, 100),
31
+ "price": round(random.uniform(5, 50), 2),
32
+ "delivery_days": random.randint(1, 14),
33
+ "status": random.choice(
34
+ ["Delivered", "Pending", "Delayed"]
35
+ )
36
+ })
37
+
38
+ logger.info("Raw data generated: %d rows", len(data))
39
+ return data
40
+
41
+ DATA = generate_data()
42
+
43
+ # ==================================================
44
+ # TOOLS (FUNCTION CALLING)
45
+ # ==================================================
46
  def tool_get_stats():
47
  logger.info("Tool called: get_stats")
48
+
49
+ prices = [x["price"] for x in DATA]
50
+ days = [x["delivery_days"] for x in DATA]
51
+
52
+ return {
53
+ "rows": len(DATA),
54
+ "avg_price": round(sum(prices) / len(prices), 2),
55
+ "avg_delivery_days": round(sum(days) / len(days), 2),
56
+ }
57
+
58
+ def tool_query_product(product: str):
59
  logger.info("Tool called: query_product(%s)", product)
60
+
61
+ return [
62
+ {
63
+ "order_id": x["order_id"],
64
+ "supplier": x["supplier"],
65
+ "quantity": x["quantity"],
66
+ "price": x["price"],
67
+ "delivery_days": x["delivery_days"],
68
+ "status": x["status"],
69
+ }
70
+ for x in DATA
71
+ if x["product"].lower() == product.lower()
72
+ ][:5]
73
+
74
+ def tool_create_support_ticket(text: str):
75
  ticket_id = str(uuid.uuid4())[:8]
76
  logger.info("Support ticket created: %s", ticket_id)
77
+
78
  return {
79
  "ticket_id": ticket_id,
80
  "system": "GitHub Issues (mock)",
 
82
  "description": text,
83
  }
84
 
85
+ # ==================================================
86
  # SAFETY
87
+ # ==================================================
88
+ def is_dangerous(text: str):
89
  blocked = ["delete", "drop", "truncate", "remove"]
90
  return any(b in text.lower() for b in blocked)
91
 
92
+ # ==================================================
93
  # AGENT (INTENT β†’ FUNCTION CALL)
94
+ # ==================================================
95
+ def agent(user_input: str, client: InferenceClient):
96
+ logger.info("User input: %s", user_input)
97
+
98
  if is_dangerous(user_input):
99
+ logger.warning("Blocked dangerous operation")
100
  return "❌ Dangerous operation is not allowed."
101
 
102
  text = user_input.lower()
103
 
104
+ # ---- FUNCTION CALLS ----
105
  if "stats" in text or "summary" in text:
106
+ s = tool_get_stats()
107
  return (
108
  f"πŸ“Š **Business Overview**\n"
109
+ f"- Rows: {s['rows']}\n"
110
+ f"- Avg Price: ${s['avg_price']}\n"
111
+ f"- Avg Delivery Days: {s['avg_delivery_days']}"
112
  )
113
 
114
  if text.startswith("show"):
115
  product = user_input.replace("show", "").strip()
116
+ result = tool_query_product(product)
117
+
118
+ if not result:
119
+ return (
120
+ "No data found for this product.\n\n"
121
+ "Would you like me to create a support ticket?"
122
+ )
123
+ return result
124
 
125
  if "support" in text or "ticket" in text:
126
+ t = tool_create_support_ticket(user_input)
127
  return (
128
  f"🎫 **Support Ticket Created**\n"
129
+ f"- ID: {t['ticket_id']}\n"
130
+ f"- System: {t['system']}\n"
131
+ f"- Status: {t['status']}"
132
  )
133
 
134
+ # ---- LLM FALLBACK (NO RAW DATA PASSED) ----
135
  prompt = f"""
136
  You are an AI procurement assistant.
137
+ You do NOT have access to raw data.
138
+ If the user asks something unclear, suggest contacting support.
139
 
140
  User: {user_input}
141
  Assistant:
142
  """
143
  logger.info("LLM called")
144
+ return client.text_generation(
145
+ prompt,
146
+ max_new_tokens=150,
147
+ temperature=0.3
148
+ )
149
 
150
+ # ==================================================
151
  # STREAMLIT UI
152
+ # ==================================================
153
  st.set_page_config(page_title="AI Procurement Agent", layout="wide")
154
+ st.title("πŸ€– AI Procurement Agent (Single-file MVP)")
155
 
156
+ # ---- Sidebar: Business Info ----
157
  stats = tool_get_stats()
158
+ st.sidebar.header("πŸ“Š Business Info")
159
  st.sidebar.metric("Rows", stats["rows"])
160
  st.sidebar.metric("Avg Price", f"${stats['avg_price']}")
161
  st.sidebar.metric("Avg Delivery Days", stats["avg_delivery_days"])
 
167
  create support ticket
168
  """)
169
 
170
+ # ---- HF Token ----
171
+ hf_token = st.sidebar.text_input(
172
+ "Hugging Face Token",
173
+ type="password",
174
+ help="https://huggingface.co/settings/tokens"
175
+ )
176
+
177
  if not hf_token:
178
+ st.warning("Please provide HF token")
179
  st.stop()
180
 
181
  client = InferenceClient(
 
183
  token=hf_token
184
  )
185
 
186
+ # ---- Chat ----
187
  if "messages" not in st.session_state:
188
  st.session_state.messages = []
189
 
190
+ user_input = st.chat_input("Ask the procurement agent...")
191
 
192
  if user_input:
193
  st.session_state.messages.append(("user", user_input))