Jurabek commited on
Commit
9bd7e4f
Β·
verified Β·
1 Parent(s): 9a541b2

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +63 -65
src/streamlit_app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import uuid
3
  import logging
 
4
  import pandas as pd
5
  import numpy as np
6
  import streamlit as st
@@ -18,13 +19,13 @@ logger = logging.getLogger("HF-AI-Agent")
18
  HF_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
19
 
20
  # ======================
21
- # LOAD DATA (600+ rows)
22
  # ======================
23
  @st.cache_data
24
  def load_data():
25
  np.random.seed(42)
26
  rows = 600
27
- df = pd.DataFrame({
28
  "order_id": range(1, rows + 1),
29
  "supplier": np.random.choice(
30
  ["Supplier A", "Supplier B", "Supplier C", "Supplier D"], rows
@@ -39,8 +40,6 @@ def load_data():
39
  ["Delivered", "Pending", "Delayed"], rows
40
  ),
41
  })
42
- logger.info("Dataset loaded (%d rows)", len(df))
43
- return df
44
 
45
  df = load_data()
46
 
@@ -48,11 +47,9 @@ df = load_data()
48
  # TOOLS
49
  # ======================
50
  def query_database(product: str):
51
- logger.info("Tool call: query_database(%s)", product)
52
  return df[df["product"].str.lower() == product.lower()].head(5)
53
 
54
  def get_aggregates():
55
- logger.info("Tool call: get_aggregates()")
56
  return {
57
  "rows": len(df),
58
  "avg_price": round(df["price"].mean(), 2),
@@ -60,10 +57,8 @@ def get_aggregates():
60
  }
61
 
62
  def create_support_ticket(text: str):
63
- ticket_id = str(uuid.uuid4())[:8]
64
- logger.info("Support ticket created: %s", ticket_id)
65
  return {
66
- "ticket_id": ticket_id,
67
  "system": "GitHub Issues (mock)",
68
  "status": "Created",
69
  "description": text,
@@ -77,60 +72,57 @@ def is_dangerous(text: str):
77
  return any(b in text.lower() for b in blocked)
78
 
79
  # ======================
80
- # STREAMLIT UI
81
  # ======================
82
- st.set_page_config(page_title="AI Procurement Agent (MVP)", layout="wide")
83
- st.title("πŸ€– AI Procurement Agent β€” MVP")
84
-
85
- st.caption("Hugging Face powered, data-aware, safe AI agent demo")
86
 
87
- # ======================
88
- # TOKEN HANDLING (IMPORTANT FIX)
89
- # ======================
90
- hf_token = os.getenv("HF_TOKEN")
91
 
92
- if not hf_token:
93
- hf_token = st.sidebar.text_input(
94
- "πŸ”‘ Hugging Face API Token",
95
- type="password",
96
- help="Create a token at https://huggingface.co/settings/tokens"
97
- )
98
 
99
- if not hf_token:
100
- st.warning("Please provide a Hugging Face API token to continue.")
101
- st.stop()
 
 
 
 
102
 
103
- client = InferenceClient(
104
- model=HF_MODEL,
105
- token=hf_token
106
- )
107
 
108
  # ======================
109
- # AGENT LOGIC
110
  # ======================
111
- def agent(user_input: str):
112
- if is_dangerous(user_input):
113
- logger.warning("Blocked unsafe operation")
 
 
114
  return "❌ This operation is not allowed."
115
 
116
- if "stats" in user_input or "summary" in user_input:
117
  s = get_aggregates()
118
  return (
119
  f"πŸ“Š **Database Summary**\n"
120
  f"- Rows: {s['rows']}\n"
121
- f"- Average Price: ${s['avg_price']}\n"
122
  f"- Avg Delivery Days: {s['avg_delivery_days']}"
123
  )
124
 
125
- if user_input.lower().startswith("show"):
126
- product = user_input.replace("show", "").strip()
127
  result = query_database(product)
128
- if result.empty:
129
- return "No records found."
130
- return result
131
 
132
- if "support" in user_input or "ticket" in user_input:
133
- ticket = create_support_ticket(user_input)
134
  return (
135
  f"🎫 **Support Ticket Created**\n"
136
  f"- ID: {ticket['ticket_id']}\n"
@@ -138,15 +130,15 @@ def agent(user_input: str):
138
  f"- Status: {ticket['status']}"
139
  )
140
 
 
141
  prompt = f"""
142
  You are an AI procurement assistant.
143
- You do NOT have direct database access.
144
- Answer concisely and professionally.
145
 
146
- User: {user_input}
147
  Assistant:
148
  """
149
- logger.info("HF model called")
150
  return client.text_generation(
151
  prompt,
152
  max_new_tokens=200,
@@ -154,24 +146,29 @@ Assistant:
154
  )
155
 
156
  # ======================
157
- # SIDEBAR – BUSINESS INFO
158
  # ======================
159
- st.sidebar.header("πŸ“Š Business Overview")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  stats = get_aggregates()
161
  st.sidebar.metric("Rows", stats["rows"])
162
  st.sidebar.metric("Avg Price", f"${stats['avg_price']}")
163
  st.sidebar.metric("Avg Delivery Days", stats["avg_delivery_days"])
164
 
165
- st.sidebar.markdown("### Sample Queries")
166
- st.sidebar.code("""
167
- show tomato
168
- database stats
169
- create support ticket
170
- """)
171
-
172
- # ======================
173
- # CHAT
174
- # ======================
175
  if "messages" not in st.session_state:
176
  st.session_state.messages = []
177
 
@@ -179,11 +176,12 @@ user_input = st.chat_input("Ask the procurement agent...")
179
 
180
  if user_input:
181
  st.session_state.messages.append(("user", user_input))
182
- response = agent(user_input)
 
 
 
183
  st.session_state.messages.append(("assistant", response))
184
 
 
185
  for role, msg in st.session_state.messages:
186
- if role == "user":
187
- st.chat_message("user").write(msg)
188
- else:
189
- st.chat_message("assistant").write(msg)
 
1
  import os
2
  import uuid
3
  import logging
4
+ import json
5
  import pandas as pd
6
  import numpy as np
7
  import streamlit as st
 
19
  HF_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
20
 
21
  # ======================
22
+ # LOAD DATA
23
  # ======================
24
  @st.cache_data
25
  def load_data():
26
  np.random.seed(42)
27
  rows = 600
28
+ return pd.DataFrame({
29
  "order_id": range(1, rows + 1),
30
  "supplier": np.random.choice(
31
  ["Supplier A", "Supplier B", "Supplier C", "Supplier D"], rows
 
40
  ["Delivered", "Pending", "Delayed"], rows
41
  ),
42
  })
 
 
43
 
44
  df = load_data()
45
 
 
47
  # TOOLS
48
  # ======================
49
  def query_database(product: str):
 
50
  return df[df["product"].str.lower() == product.lower()].head(5)
51
 
52
  def get_aggregates():
 
53
  return {
54
  "rows": len(df),
55
  "avg_price": round(df["price"].mean(), 2),
 
57
  }
58
 
59
  def create_support_ticket(text: str):
 
 
60
  return {
61
+ "ticket_id": str(uuid.uuid4())[:8],
62
  "system": "GitHub Issues (mock)",
63
  "status": "Created",
64
  "description": text,
 
72
  return any(b in text.lower() for b in blocked)
73
 
74
  # ======================
75
+ # JSON BUILDER (1-QADAM)
76
  # ======================
77
+ def build_request_json(user_input: str):
78
+ intent = "llm"
 
 
79
 
80
+ text = user_input.lower()
 
 
 
81
 
82
+ if "stats" in text or "summary" in text:
83
+ intent = "stats"
84
+ elif text.startswith("show"):
85
+ intent = "query"
86
+ elif "support" in text or "ticket" in text:
87
+ intent = "support"
88
 
89
+ request_json = {
90
+ "request_id": str(uuid.uuid4()),
91
+ "intent": intent,
92
+ "payload": {
93
+ "text": user_input
94
+ }
95
+ }
96
 
97
+ logger.info("REQUEST JSON: %s", request_json)
98
+ return request_json
 
 
99
 
100
  # ======================
101
+ # ROUTER (2-QADAM)
102
  # ======================
103
+ def router(req: dict, client: InferenceClient):
104
+ intent = req["intent"]
105
+ text = req["payload"]["text"]
106
+
107
+ if is_dangerous(text):
108
  return "❌ This operation is not allowed."
109
 
110
+ if intent == "stats":
111
  s = get_aggregates()
112
  return (
113
  f"πŸ“Š **Database Summary**\n"
114
  f"- Rows: {s['rows']}\n"
115
+ f"- Avg Price: ${s['avg_price']}\n"
116
  f"- Avg Delivery Days: {s['avg_delivery_days']}"
117
  )
118
 
119
+ if intent == "query":
120
+ product = text.replace("show", "").strip()
121
  result = query_database(product)
122
+ return "No records found." if result.empty else result
 
 
123
 
124
+ if intent == "support":
125
+ ticket = create_support_ticket(text)
126
  return (
127
  f"🎫 **Support Ticket Created**\n"
128
  f"- ID: {ticket['ticket_id']}\n"
 
130
  f"- Status: {ticket['status']}"
131
  )
132
 
133
+ # LLM fallback
134
  prompt = f"""
135
  You are an AI procurement assistant.
136
+ You do NOT have database access.
137
+ Answer shortly and professionally.
138
 
139
+ User: {text}
140
  Assistant:
141
  """
 
142
  return client.text_generation(
143
  prompt,
144
  max_new_tokens=200,
 
146
  )
147
 
148
  # ======================
149
+ # STREAMLIT UI
150
  # ======================
151
+ st.set_page_config(page_title="AI Procurement Agent (JSON-based)", layout="wide")
152
+ st.title("πŸ€– AI Procurement Agent β€” JSON Router MVP")
153
+
154
+ hf_token = os.getenv("HF_TOKEN") or st.sidebar.text_input(
155
+ "πŸ”‘ Hugging Face API Token",
156
+ type="password"
157
+ )
158
+
159
+ if not hf_token:
160
+ st.warning("HF token required")
161
+ st.stop()
162
+
163
+ client = InferenceClient(model=HF_MODEL, token=hf_token)
164
+
165
+ # Sidebar
166
  stats = get_aggregates()
167
  st.sidebar.metric("Rows", stats["rows"])
168
  st.sidebar.metric("Avg Price", f"${stats['avg_price']}")
169
  st.sidebar.metric("Avg Delivery Days", stats["avg_delivery_days"])
170
 
171
+ # Chat memory
 
 
 
 
 
 
 
 
 
172
  if "messages" not in st.session_state:
173
  st.session_state.messages = []
174
 
 
176
 
177
  if user_input:
178
  st.session_state.messages.append(("user", user_input))
179
+
180
+ request_json = build_request_json(user_input)
181
+ response = router(request_json, client)
182
+
183
  st.session_state.messages.append(("assistant", response))
184
 
185
+ # Render chat
186
  for role, msg in st.session_state.messages:
187
+ st.chat_message(role).write(msg)