Spaces:
Sleeping
Sleeping
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +63 -50
src/streamlit_app.py
CHANGED
|
@@ -7,14 +7,17 @@ from huggingface_hub import InferenceClient
|
|
| 7 |
# ==================================================
|
| 8 |
# LOGGING
|
| 9 |
# ==================================================
|
| 10 |
-
logging.basicConfig(
|
|
|
|
|
|
|
|
|
|
| 11 |
logger = logging.getLogger("AI-Agent")
|
| 12 |
|
| 13 |
# ==================================================
|
| 14 |
# RAW DATA (IN-MEMORY, 600+ ROWS)
|
| 15 |
# ==================================================
|
| 16 |
@st.cache_data
|
| 17 |
-
def generate_data(rows=600):
|
| 18 |
random.seed(42)
|
| 19 |
data = []
|
| 20 |
|
|
@@ -35,7 +38,7 @@ def generate_data(rows=600):
|
|
| 35 |
)
|
| 36 |
})
|
| 37 |
|
| 38 |
-
logger.info("
|
| 39 |
return data
|
| 40 |
|
| 41 |
DATA = generate_data()
|
|
@@ -44,7 +47,7 @@ DATA = generate_data()
|
|
| 44 |
# TOOLS (FUNCTION CALLING)
|
| 45 |
# ==================================================
|
| 46 |
def tool_get_stats():
|
| 47 |
-
logger.info("Tool
|
| 48 |
|
| 49 |
prices = [x["price"] for x in DATA]
|
| 50 |
days = [x["delivery_days"] for x in DATA]
|
|
@@ -56,7 +59,7 @@ def tool_get_stats():
|
|
| 56 |
}
|
| 57 |
|
| 58 |
def tool_query_product(product: str):
|
| 59 |
-
logger.info("Tool
|
| 60 |
|
| 61 |
return [
|
| 62 |
{
|
|
@@ -73,7 +76,7 @@ def tool_query_product(product: str):
|
|
| 73 |
|
| 74 |
def tool_create_support_ticket(text: str):
|
| 75 |
ticket_id = str(uuid.uuid4())[:8]
|
| 76 |
-
logger.info("
|
| 77 |
|
| 78 |
return {
|
| 79 |
"ticket_id": ticket_id,
|
|
@@ -85,7 +88,7 @@ def tool_create_support_ticket(text: str):
|
|
| 85 |
# ==================================================
|
| 86 |
# SAFETY
|
| 87 |
# ==================================================
|
| 88 |
-
def is_dangerous(text: str):
|
| 89 |
blocked = ["delete", "drop", "truncate", "remove"]
|
| 90 |
return any(b in text.lower() for b in blocked)
|
| 91 |
|
|
@@ -97,85 +100,95 @@ def agent(user_input: str, client: InferenceClient):
|
|
| 97 |
|
| 98 |
if is_dangerous(user_input):
|
| 99 |
logger.warning("Blocked dangerous operation")
|
| 100 |
-
return "β Dangerous
|
| 101 |
|
| 102 |
text = user_input.lower()
|
| 103 |
|
| 104 |
-
# ---- FUNCTION CALLS ----
|
| 105 |
if "stats" in text or "summary" in text:
|
| 106 |
-
|
| 107 |
return (
|
| 108 |
f"π **Business Overview**\n"
|
| 109 |
-
f"- Rows: {
|
| 110 |
-
f"- Avg Price: ${
|
| 111 |
-
f"- Avg Delivery Days: {
|
| 112 |
)
|
| 113 |
|
| 114 |
if text.startswith("show"):
|
| 115 |
product = user_input.replace("show", "").strip()
|
| 116 |
-
|
| 117 |
|
| 118 |
-
if not
|
| 119 |
return (
|
| 120 |
-
"No
|
| 121 |
"Would you like me to create a support ticket?"
|
| 122 |
)
|
| 123 |
-
return
|
| 124 |
|
| 125 |
if "support" in text or "ticket" in text:
|
| 126 |
-
|
| 127 |
return (
|
| 128 |
f"π« **Support Ticket Created**\n"
|
| 129 |
-
f"- ID: {
|
| 130 |
-
f"- System: {
|
| 131 |
-
f"- Status: {
|
| 132 |
)
|
| 133 |
|
| 134 |
-
# ---- LLM FALLBACK (
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
""
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
temperature=0.3
|
| 148 |
)
|
| 149 |
|
|
|
|
|
|
|
| 150 |
# ==================================================
|
| 151 |
# STREAMLIT UI
|
| 152 |
# ==================================================
|
| 153 |
st.set_page_config(page_title="AI Procurement Agent", layout="wide")
|
| 154 |
-
st.title("π€ AI Procurement Agent
|
| 155 |
|
| 156 |
-
# ---- Sidebar: Business Info ----
|
| 157 |
stats = tool_get_stats()
|
| 158 |
-
|
|
|
|
| 159 |
st.sidebar.metric("Rows", stats["rows"])
|
| 160 |
st.sidebar.metric("Avg Price", f"${stats['avg_price']}")
|
| 161 |
st.sidebar.metric("Avg Delivery Days", stats["avg_delivery_days"])
|
| 162 |
|
| 163 |
st.sidebar.markdown("### Sample queries")
|
| 164 |
-
st.sidebar.code(
|
| 165 |
-
show tomato
|
| 166 |
-
database stats
|
| 167 |
-
create support ticket
|
| 168 |
-
|
| 169 |
|
| 170 |
-
# ---- HF
|
| 171 |
hf_token = st.sidebar.text_input(
|
| 172 |
"Hugging Face Token",
|
| 173 |
type="password",
|
| 174 |
-
help="https://huggingface.co/settings/tokens"
|
| 175 |
)
|
| 176 |
|
| 177 |
if not hf_token:
|
| 178 |
-
st.warning("Please provide
|
| 179 |
st.stop()
|
| 180 |
|
| 181 |
client = InferenceClient(
|
|
@@ -183,7 +196,7 @@ client = InferenceClient(
|
|
| 183 |
token=hf_token
|
| 184 |
)
|
| 185 |
|
| 186 |
-
# ---- Chat ----
|
| 187 |
if "messages" not in st.session_state:
|
| 188 |
st.session_state.messages = []
|
| 189 |
|
|
@@ -191,8 +204,8 @@ user_input = st.chat_input("Ask the procurement agent...")
|
|
| 191 |
|
| 192 |
if user_input:
|
| 193 |
st.session_state.messages.append(("user", user_input))
|
| 194 |
-
|
| 195 |
-
st.session_state.messages.append(("assistant",
|
| 196 |
|
| 197 |
-
for role,
|
| 198 |
-
st.chat_message(role).write(
|
|
|
|
| 7 |
# ==================================================
|
| 8 |
# LOGGING
|
| 9 |
# ==================================================
|
| 10 |
+
logging.basicConfig(
|
| 11 |
+
level=logging.INFO,
|
| 12 |
+
format="%(asctime)s | %(levelname)s | %(message)s"
|
| 13 |
+
)
|
| 14 |
logger = logging.getLogger("AI-Agent")
|
| 15 |
|
| 16 |
# ==================================================
|
| 17 |
# RAW DATA (IN-MEMORY, 600+ ROWS)
|
| 18 |
# ==================================================
|
| 19 |
@st.cache_data
|
| 20 |
+
def generate_data(rows: int = 600):
|
| 21 |
random.seed(42)
|
| 22 |
data = []
|
| 23 |
|
|
|
|
| 38 |
)
|
| 39 |
})
|
| 40 |
|
| 41 |
+
logger.info("Generated raw dataset with %d rows", len(data))
|
| 42 |
return data
|
| 43 |
|
| 44 |
DATA = generate_data()
|
|
|
|
| 47 |
# TOOLS (FUNCTION CALLING)
|
| 48 |
# ==================================================
|
| 49 |
def tool_get_stats():
|
| 50 |
+
logger.info("Tool call β get_stats")
|
| 51 |
|
| 52 |
prices = [x["price"] for x in DATA]
|
| 53 |
days = [x["delivery_days"] for x in DATA]
|
|
|
|
| 59 |
}
|
| 60 |
|
| 61 |
def tool_query_product(product: str):
|
| 62 |
+
logger.info("Tool call β query_product(%s)", product)
|
| 63 |
|
| 64 |
return [
|
| 65 |
{
|
|
|
|
| 76 |
|
| 77 |
def tool_create_support_ticket(text: str):
|
| 78 |
ticket_id = str(uuid.uuid4())[:8]
|
| 79 |
+
logger.info("Tool call β create_support_ticket (%s)", ticket_id)
|
| 80 |
|
| 81 |
return {
|
| 82 |
"ticket_id": ticket_id,
|
|
|
|
| 88 |
# ==================================================
|
| 89 |
# SAFETY
|
| 90 |
# ==================================================
|
| 91 |
+
def is_dangerous(text: str) -> bool:
|
| 92 |
blocked = ["delete", "drop", "truncate", "remove"]
|
| 93 |
return any(b in text.lower() for b in blocked)
|
| 94 |
|
|
|
|
| 100 |
|
| 101 |
if is_dangerous(user_input):
|
| 102 |
logger.warning("Blocked dangerous operation")
|
| 103 |
+
return "β Dangerous operations are not allowed."
|
| 104 |
|
| 105 |
text = user_input.lower()
|
| 106 |
|
| 107 |
+
# -------- FUNCTION CALLS --------
|
| 108 |
if "stats" in text or "summary" in text:
|
| 109 |
+
stats = tool_get_stats()
|
| 110 |
return (
|
| 111 |
f"π **Business Overview**\n"
|
| 112 |
+
f"- Rows: {stats['rows']}\n"
|
| 113 |
+
f"- Avg Price: ${stats['avg_price']}\n"
|
| 114 |
+
f"- Avg Delivery Days: {stats['avg_delivery_days']}"
|
| 115 |
)
|
| 116 |
|
| 117 |
if text.startswith("show"):
|
| 118 |
product = user_input.replace("show", "").strip()
|
| 119 |
+
rows = tool_query_product(product)
|
| 120 |
|
| 121 |
+
if not rows:
|
| 122 |
return (
|
| 123 |
+
"No records found for this product.\n\n"
|
| 124 |
"Would you like me to create a support ticket?"
|
| 125 |
)
|
| 126 |
+
return rows
|
| 127 |
|
| 128 |
if "support" in text or "ticket" in text:
|
| 129 |
+
ticket = tool_create_support_ticket(user_input)
|
| 130 |
return (
|
| 131 |
f"π« **Support Ticket Created**\n"
|
| 132 |
+
f"- ID: {ticket['ticket_id']}\n"
|
| 133 |
+
f"- System: {ticket['system']}\n"
|
| 134 |
+
f"- Status: {ticket['status']}"
|
| 135 |
)
|
| 136 |
|
| 137 |
+
# -------- LLM FALLBACK (CHAT MODE) --------
|
| 138 |
+
logger.info("LLM fallback β chat_completion")
|
| 139 |
+
|
| 140 |
+
response = client.chat_completion(
|
| 141 |
+
messages=[
|
| 142 |
+
{
|
| 143 |
+
"role": "system",
|
| 144 |
+
"content": (
|
| 145 |
+
"You are an AI procurement assistant. "
|
| 146 |
+
"You do NOT have access to raw data. "
|
| 147 |
+
"Answer concisely and professionally. "
|
| 148 |
+
"If unsure, suggest creating a support ticket."
|
| 149 |
+
)
|
| 150 |
+
},
|
| 151 |
+
{
|
| 152 |
+
"role": "user",
|
| 153 |
+
"content": user_input
|
| 154 |
+
}
|
| 155 |
+
],
|
| 156 |
+
max_tokens=150,
|
| 157 |
temperature=0.3
|
| 158 |
)
|
| 159 |
|
| 160 |
+
return response.choices[0].message.content
|
| 161 |
+
|
| 162 |
# ==================================================
|
| 163 |
# STREAMLIT UI
|
| 164 |
# ==================================================
|
| 165 |
st.set_page_config(page_title="AI Procurement Agent", layout="wide")
|
| 166 |
+
st.title("π€ AI Procurement Agent β Single File MVP")
|
| 167 |
|
| 168 |
+
# -------- Sidebar: Business Info --------
|
| 169 |
stats = tool_get_stats()
|
| 170 |
+
|
| 171 |
+
st.sidebar.header("π Business Information")
|
| 172 |
st.sidebar.metric("Rows", stats["rows"])
|
| 173 |
st.sidebar.metric("Avg Price", f"${stats['avg_price']}")
|
| 174 |
st.sidebar.metric("Avg Delivery Days", stats["avg_delivery_days"])
|
| 175 |
|
| 176 |
st.sidebar.markdown("### Sample queries")
|
| 177 |
+
st.sidebar.code(
|
| 178 |
+
"show tomato\n"
|
| 179 |
+
"database stats\n"
|
| 180 |
+
"create support ticket"
|
| 181 |
+
)
|
| 182 |
|
| 183 |
+
# -------- HF TOKEN --------
|
| 184 |
hf_token = st.sidebar.text_input(
|
| 185 |
"Hugging Face Token",
|
| 186 |
type="password",
|
| 187 |
+
help="Create token at https://huggingface.co/settings/tokens"
|
| 188 |
)
|
| 189 |
|
| 190 |
if not hf_token:
|
| 191 |
+
st.warning("Please provide a Hugging Face token.")
|
| 192 |
st.stop()
|
| 193 |
|
| 194 |
client = InferenceClient(
|
|
|
|
| 196 |
token=hf_token
|
| 197 |
)
|
| 198 |
|
| 199 |
+
# -------- Chat --------
|
| 200 |
if "messages" not in st.session_state:
|
| 201 |
st.session_state.messages = []
|
| 202 |
|
|
|
|
| 204 |
|
| 205 |
if user_input:
|
| 206 |
st.session_state.messages.append(("user", user_input))
|
| 207 |
+
reply = agent(user_input, client)
|
| 208 |
+
st.session_state.messages.append(("assistant", reply))
|
| 209 |
|
| 210 |
+
for role, message in st.session_state.messages:
|
| 211 |
+
st.chat_message(role).write(message)
|