Jurabek commited on
Commit
8ac57a8
Β·
verified Β·
1 Parent(s): 86823d0

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +171 -33
src/streamlit_app.py CHANGED
@@ -1,40 +1,178 @@
1
- import altair as alt
2
- import numpy as np
 
3
  import pandas as pd
 
4
  import streamlit as st
 
5
 
6
- """
7
- # Welcome to Streamlit!
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
 
 
 
 
12
 
13
- In the meantime, below is an example of what you can do with just a few lines of code:
 
 
 
 
 
 
14
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ import os
2
+ import uuid
3
+ import logging
4
  import pandas as pd
5
+ import numpy as np
6
  import streamlit as st
7
+ from huggingface_hub import InferenceClient
8
 
9
+ # ======================
10
+ # LOGGING
11
+ # ======================
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger("AI-Agent")
14
+
15
+ # ======================
16
+ # CONFIG
17
+ # ======================
18
+ HF_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
19
+
20
+ # ======================
21
+ # LOAD DATA (500+ rows)
22
+ # ======================
23
+ @st.cache_data
24
+ def load_data():
25
+ np.random.seed(42)
26
+ rows = 600
27
+ df = pd.DataFrame({
28
+ "order_id": range(1, rows + 1),
29
+ "supplier": np.random.choice(
30
+ ["Supplier A", "Supplier B", "Supplier C", "Supplier D"], rows
31
+ ),
32
+ "product": np.random.choice(
33
+ ["Tomato", "Cheese", "Flour", "Oil", "Meat"], rows
34
+ ),
35
+ "quantity": np.random.randint(1, 100, rows),
36
+ "price": np.random.uniform(5, 50, rows).round(2),
37
+ "delivery_days": np.random.randint(1, 14, rows),
38
+ "status": np.random.choice(
39
+ ["Delivered", "Pending", "Delayed"], rows
40
+ ),
41
+ })
42
+ logger.info("Loaded dataset with %d rows", len(df))
43
+ return df
44
+
45
+ df = load_data()
46
+
47
+ # ======================
48
+ # TOOLS (FUNCTIONS)
49
+ # ======================
50
+ def query_database(product: str):
51
+ logger.info("Function call: query_database(%s)", product)
52
+ return df[df["product"].str.lower() == product.lower()].head(5)
53
+
54
+ def get_aggregates():
55
+ logger.info("Function call: get_aggregates()")
56
+ return {
57
+ "rows": len(df),
58
+ "avg_price": round(df["price"].mean(), 2),
59
+ "avg_delivery_days": round(df["delivery_days"].mean(), 2),
60
+ }
61
+
62
+ def create_support_ticket(text: str):
63
+ ticket_id = str(uuid.uuid4())[:8]
64
+ logger.info("Support ticket created: %s", ticket_id)
65
+ return {
66
+ "ticket_id": ticket_id,
67
+ "system": "GitHub Issues (mock)",
68
+ "status": "Created",
69
+ "description": text,
70
+ }
71
+
72
+ # ======================
73
+ # SAFETY CHECK
74
+ # ======================
75
+ def safety_guard(user_input: str):
76
+ blocked = ["delete", "drop", "truncate", "remove table"]
77
+ return any(b in user_input.lower() for b in blocked)
78
+
79
+ # ======================
80
+ # HF AGENT LOGIC
81
+ # ======================
82
+ def agent_response(user_input: str, client: InferenceClient):
83
+ if safety_guard(user_input):
84
+ logger.warning("Blocked dangerous request")
85
+ return "❌ This operation is not allowed."
86
+
87
+ # TOOL ROUTING (simple intent routing)
88
+ if "stats" in user_input or "summary" in user_input:
89
+ stats = get_aggregates()
90
+ return (
91
+ f"πŸ“Š Database Summary\n"
92
+ f"- Rows: {stats['rows']}\n"
93
+ f"- Avg Price: ${stats['avg_price']}\n"
94
+ f"- Avg Delivery Days: {stats['avg_delivery_days']}"
95
+ )
96
+
97
+ if user_input.lower().startswith("show"):
98
+ product = user_input.replace("show", "").strip()
99
+ result = query_database(product)
100
+ if result.empty:
101
+ return "No data found for that product."
102
+ return result
103
 
104
+ if "support" in user_input or "ticket" in user_input:
105
+ ticket = create_support_ticket(user_input)
106
+ return (
107
+ f"🎫 Support Ticket Created\n"
108
+ f"- ID: {ticket['ticket_id']}\n"
109
+ f"- System: {ticket['system']}"
110
+ )
111
 
112
+ # HF LLM (TEXT ONLY, NO DB)
113
+ prompt = f"""
114
+ You are a procurement assistant.
115
+ Answer the user without assuming direct database access.
116
+
117
+ User: {user_input}
118
+ Assistant:
119
  """
120
+ logger.info("Calling HF model")
121
+
122
+ response = client.text_generation(
123
+ prompt,
124
+ max_new_tokens=200,
125
+ temperature=0.3,
126
+ )
127
+
128
+ return response
129
+
130
+ # ======================
131
+ # STREAMLIT UI
132
+ # ======================
133
+ st.set_page_config(page_title="AI Procurement Agent (HF)", layout="wide")
134
+ st.title("πŸ€– AI Procurement Agent (Hugging Face Demo)")
135
+
136
+ hf_token = st.sidebar.text_input(
137
+ "πŸ”‘ Hugging Face API Key",
138
+ type="password"
139
+ )
140
+
141
+ # Sidebar – business info
142
+ st.sidebar.header("πŸ“Š Business Overview")
143
+ stats = get_aggregates()
144
+ st.sidebar.metric("Rows in DB", stats["rows"])
145
+ st.sidebar.metric("Avg Price", f"${stats['avg_price']}")
146
+ st.sidebar.metric("Avg Delivery Days", stats["avg_delivery_days"])
147
+
148
+ st.sidebar.markdown("### Sample Queries")
149
+ st.sidebar.code("""
150
+ show tomato
151
+ database stats
152
+ create support ticket
153
+ """)
154
+
155
+ if not hf_token:
156
+ st.warning("Please enter Hugging Face API key to continue.")
157
+ st.stop()
158
+
159
+ client = InferenceClient(
160
+ model=HF_MODEL,
161
+ token=hf_token
162
+ )
163
+
164
+ if "chat" not in st.session_state:
165
+ st.session_state.chat = []
166
+
167
+ user_input = st.chat_input("Ask the agent...")
168
+
169
+ if user_input:
170
+ st.session_state.chat.append(("user", user_input))
171
+ reply = agent_response(user_input, client)
172
+ st.session_state.chat.append(("assistant", reply))
173
 
174
+ for role, msg in st.session_state.chat:
175
+ if role == "user":
176
+ st.chat_message("user").write(msg)
177
+ else:
178
+ st.chat_message("assistant").write(msg)