Jurabek commited on
Commit
c3ca079
Β·
verified Β·
1 Parent(s): ac5adde

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +63 -50
src/streamlit_app.py CHANGED
@@ -7,14 +7,17 @@ from huggingface_hub import InferenceClient
7
  # ==================================================
8
  # LOGGING
9
  # ==================================================
10
- logging.basicConfig(level=logging.INFO)
 
 
 
11
  logger = logging.getLogger("AI-Agent")
12
 
13
  # ==================================================
14
  # RAW DATA (IN-MEMORY, 600+ ROWS)
15
  # ==================================================
16
  @st.cache_data
17
- def generate_data(rows=600):
18
  random.seed(42)
19
  data = []
20
 
@@ -35,7 +38,7 @@ def generate_data(rows=600):
35
  )
36
  })
37
 
38
- logger.info("Raw data generated: %d rows", len(data))
39
  return data
40
 
41
  DATA = generate_data()
@@ -44,7 +47,7 @@ DATA = generate_data()
44
  # TOOLS (FUNCTION CALLING)
45
  # ==================================================
46
  def tool_get_stats():
47
- logger.info("Tool called: get_stats")
48
 
49
  prices = [x["price"] for x in DATA]
50
  days = [x["delivery_days"] for x in DATA]
@@ -56,7 +59,7 @@ def tool_get_stats():
56
  }
57
 
58
  def tool_query_product(product: str):
59
- logger.info("Tool called: query_product(%s)", product)
60
 
61
  return [
62
  {
@@ -73,7 +76,7 @@ def tool_query_product(product: str):
73
 
74
  def tool_create_support_ticket(text: str):
75
  ticket_id = str(uuid.uuid4())[:8]
76
- logger.info("Support ticket created: %s", ticket_id)
77
 
78
  return {
79
  "ticket_id": ticket_id,
@@ -85,7 +88,7 @@ def tool_create_support_ticket(text: str):
85
  # ==================================================
86
  # SAFETY
87
  # ==================================================
88
- def is_dangerous(text: str):
89
  blocked = ["delete", "drop", "truncate", "remove"]
90
  return any(b in text.lower() for b in blocked)
91
 
@@ -97,85 +100,95 @@ def agent(user_input: str, client: InferenceClient):
97
 
98
  if is_dangerous(user_input):
99
  logger.warning("Blocked dangerous operation")
100
- return "❌ Dangerous operation is not allowed."
101
 
102
  text = user_input.lower()
103
 
104
- # ---- FUNCTION CALLS ----
105
  if "stats" in text or "summary" in text:
106
- s = tool_get_stats()
107
  return (
108
  f"πŸ“Š **Business Overview**\n"
109
- f"- Rows: {s['rows']}\n"
110
- f"- Avg Price: ${s['avg_price']}\n"
111
- f"- Avg Delivery Days: {s['avg_delivery_days']}"
112
  )
113
 
114
  if text.startswith("show"):
115
  product = user_input.replace("show", "").strip()
116
- result = tool_query_product(product)
117
 
118
- if not result:
119
  return (
120
- "No data found for this product.\n\n"
121
  "Would you like me to create a support ticket?"
122
  )
123
- return result
124
 
125
  if "support" in text or "ticket" in text:
126
- t = tool_create_support_ticket(user_input)
127
  return (
128
  f"🎫 **Support Ticket Created**\n"
129
- f"- ID: {t['ticket_id']}\n"
130
- f"- System: {t['system']}\n"
131
- f"- Status: {t['status']}"
132
  )
133
 
134
- # ---- LLM FALLBACK (NO RAW DATA PASSED) ----
135
- prompt = f"""
136
- You are an AI procurement assistant.
137
- You do NOT have access to raw data.
138
- If the user asks something unclear, suggest contacting support.
139
-
140
- User: {user_input}
141
- Assistant:
142
- """
143
- logger.info("LLM called")
144
- return client.text_generation(
145
- prompt,
146
- max_new_tokens=150,
 
 
 
 
 
 
 
147
  temperature=0.3
148
  )
149
 
 
 
150
  # ==================================================
151
  # STREAMLIT UI
152
  # ==================================================
153
  st.set_page_config(page_title="AI Procurement Agent", layout="wide")
154
- st.title("πŸ€– AI Procurement Agent (Single-file MVP)")
155
 
156
- # ---- Sidebar: Business Info ----
157
  stats = tool_get_stats()
158
- st.sidebar.header("πŸ“Š Business Info")
 
159
  st.sidebar.metric("Rows", stats["rows"])
160
  st.sidebar.metric("Avg Price", f"${stats['avg_price']}")
161
  st.sidebar.metric("Avg Delivery Days", stats["avg_delivery_days"])
162
 
163
  st.sidebar.markdown("### Sample queries")
164
- st.sidebar.code("""
165
- show tomato
166
- database stats
167
- create support ticket
168
- """)
169
 
170
- # ---- HF Token ----
171
  hf_token = st.sidebar.text_input(
172
  "Hugging Face Token",
173
  type="password",
174
- help="https://huggingface.co/settings/tokens"
175
  )
176
 
177
  if not hf_token:
178
- st.warning("Please provide HF token")
179
  st.stop()
180
 
181
  client = InferenceClient(
@@ -183,7 +196,7 @@ client = InferenceClient(
183
  token=hf_token
184
  )
185
 
186
- # ---- Chat ----
187
  if "messages" not in st.session_state:
188
  st.session_state.messages = []
189
 
@@ -191,8 +204,8 @@ user_input = st.chat_input("Ask the procurement agent...")
191
 
192
  if user_input:
193
  st.session_state.messages.append(("user", user_input))
194
- response = agent(user_input, client)
195
- st.session_state.messages.append(("assistant", response))
196
 
197
- for role, msg in st.session_state.messages:
198
- st.chat_message(role).write(msg)
 
7
  # ==================================================
8
  # LOGGING
9
  # ==================================================
10
+ logging.basicConfig(
11
+ level=logging.INFO,
12
+ format="%(asctime)s | %(levelname)s | %(message)s"
13
+ )
14
  logger = logging.getLogger("AI-Agent")
15
 
16
  # ==================================================
17
  # RAW DATA (IN-MEMORY, 600+ ROWS)
18
  # ==================================================
19
  @st.cache_data
20
+ def generate_data(rows: int = 600):
21
  random.seed(42)
22
  data = []
23
 
 
38
  )
39
  })
40
 
41
+ logger.info("Generated raw dataset with %d rows", len(data))
42
  return data
43
 
44
  DATA = generate_data()
 
47
  # TOOLS (FUNCTION CALLING)
48
  # ==================================================
49
  def tool_get_stats():
50
+ logger.info("Tool call β†’ get_stats")
51
 
52
  prices = [x["price"] for x in DATA]
53
  days = [x["delivery_days"] for x in DATA]
 
59
  }
60
 
61
  def tool_query_product(product: str):
62
+ logger.info("Tool call β†’ query_product(%s)", product)
63
 
64
  return [
65
  {
 
76
 
77
  def tool_create_support_ticket(text: str):
78
  ticket_id = str(uuid.uuid4())[:8]
79
+ logger.info("Tool call β†’ create_support_ticket (%s)", ticket_id)
80
 
81
  return {
82
  "ticket_id": ticket_id,
 
88
  # ==================================================
89
  # SAFETY
90
  # ==================================================
91
+ def is_dangerous(text: str) -> bool:
92
  blocked = ["delete", "drop", "truncate", "remove"]
93
  return any(b in text.lower() for b in blocked)
94
 
 
100
 
101
  if is_dangerous(user_input):
102
  logger.warning("Blocked dangerous operation")
103
+ return "❌ Dangerous operations are not allowed."
104
 
105
  text = user_input.lower()
106
 
107
+ # -------- FUNCTION CALLS --------
108
  if "stats" in text or "summary" in text:
109
+ stats = tool_get_stats()
110
  return (
111
  f"πŸ“Š **Business Overview**\n"
112
+ f"- Rows: {stats['rows']}\n"
113
+ f"- Avg Price: ${stats['avg_price']}\n"
114
+ f"- Avg Delivery Days: {stats['avg_delivery_days']}"
115
  )
116
 
117
  if text.startswith("show"):
118
  product = user_input.replace("show", "").strip()
119
+ rows = tool_query_product(product)
120
 
121
+ if not rows:
122
  return (
123
+ "No records found for this product.\n\n"
124
  "Would you like me to create a support ticket?"
125
  )
126
+ return rows
127
 
128
  if "support" in text or "ticket" in text:
129
+ ticket = tool_create_support_ticket(user_input)
130
  return (
131
  f"🎫 **Support Ticket Created**\n"
132
+ f"- ID: {ticket['ticket_id']}\n"
133
+ f"- System: {ticket['system']}\n"
134
+ f"- Status: {ticket['status']}"
135
  )
136
 
137
+ # -------- LLM FALLBACK (CHAT MODE) --------
138
+ logger.info("LLM fallback β†’ chat_completion")
139
+
140
+ response = client.chat_completion(
141
+ messages=[
142
+ {
143
+ "role": "system",
144
+ "content": (
145
+ "You are an AI procurement assistant. "
146
+ "You do NOT have access to raw data. "
147
+ "Answer concisely and professionally. "
148
+ "If unsure, suggest creating a support ticket."
149
+ )
150
+ },
151
+ {
152
+ "role": "user",
153
+ "content": user_input
154
+ }
155
+ ],
156
+ max_tokens=150,
157
  temperature=0.3
158
  )
159
 
160
+ return response.choices[0].message.content
161
+
162
  # ==================================================
163
  # STREAMLIT UI
164
  # ==================================================
165
  st.set_page_config(page_title="AI Procurement Agent", layout="wide")
166
+ st.title("πŸ€– AI Procurement Agent β€” Single File MVP")
167
 
168
+ # -------- Sidebar: Business Info --------
169
  stats = tool_get_stats()
170
+
171
+ st.sidebar.header("πŸ“Š Business Information")
172
  st.sidebar.metric("Rows", stats["rows"])
173
  st.sidebar.metric("Avg Price", f"${stats['avg_price']}")
174
  st.sidebar.metric("Avg Delivery Days", stats["avg_delivery_days"])
175
 
176
  st.sidebar.markdown("### Sample queries")
177
+ st.sidebar.code(
178
+ "show tomato\n"
179
+ "database stats\n"
180
+ "create support ticket"
181
+ )
182
 
183
+ # -------- HF TOKEN --------
184
  hf_token = st.sidebar.text_input(
185
  "Hugging Face Token",
186
  type="password",
187
+ help="Create token at https://huggingface.co/settings/tokens"
188
  )
189
 
190
  if not hf_token:
191
+ st.warning("Please provide a Hugging Face token.")
192
  st.stop()
193
 
194
  client = InferenceClient(
 
196
  token=hf_token
197
  )
198
 
199
+ # -------- Chat --------
200
  if "messages" not in st.session_state:
201
  st.session_state.messages = []
202
 
 
204
 
205
  if user_input:
206
  st.session_state.messages.append(("user", user_input))
207
+ reply = agent(user_input, client)
208
+ st.session_state.messages.append(("assistant", reply))
209
 
210
+ for role, message in st.session_state.messages:
211
+ st.chat_message(role).write(message)