Spaces:
Sleeping
Sleeping
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +40 -17
src/streamlit_app.py
CHANGED
|
@@ -16,10 +16,6 @@ logger = logging.getLogger("HF-AI-Agent")
|
|
| 16 |
# CONFIG
|
| 17 |
# ======================
|
| 18 |
HF_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
|
| 19 |
-
HF_TOKEN = os.getenv("HF_TOKEN") # HF Spaces Secret
|
| 20 |
-
|
| 21 |
-
if HF_TOKEN is None:
|
| 22 |
-
raise RuntimeError("HF_TOKEN secret is missing")
|
| 23 |
|
| 24 |
# ======================
|
| 25 |
# LOAD DATA (600+ rows)
|
|
@@ -49,7 +45,7 @@ def load_data():
|
|
| 49 |
df = load_data()
|
| 50 |
|
| 51 |
# ======================
|
| 52 |
-
# TOOLS
|
| 53 |
# ======================
|
| 54 |
def query_database(product: str):
|
| 55 |
logger.info("Tool call: query_database(%s)", product)
|
|
@@ -74,17 +70,44 @@ def create_support_ticket(text: str):
|
|
| 74 |
}
|
| 75 |
|
| 76 |
# ======================
|
| 77 |
-
# SAFETY
|
| 78 |
# ======================
|
| 79 |
def is_dangerous(text: str):
|
| 80 |
blocked = ["delete", "drop", "truncate", "remove table"]
|
| 81 |
return any(b in text.lower() for b in blocked)
|
| 82 |
|
| 83 |
# ======================
|
| 84 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
# ======================
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
def agent(user_input: str):
|
| 89 |
if is_dangerous(user_input):
|
| 90 |
logger.warning("Blocked unsafe operation")
|
|
@@ -124,17 +147,15 @@ User: {user_input}
|
|
| 124 |
Assistant:
|
| 125 |
"""
|
| 126 |
logger.info("HF model called")
|
| 127 |
-
return client.text_generation(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
# ======================
|
| 130 |
-
#
|
| 131 |
# ======================
|
| 132 |
-
st.set_page_config(page_title="AI Procurement Agent (MVP)", layout="wide")
|
| 133 |
-
st.title("π€ AI Procurement Agent β MVP (HF Spaces)")
|
| 134 |
-
|
| 135 |
-
st.caption("Minimal, secure, data-aware AI agent demo")
|
| 136 |
-
|
| 137 |
-
# Sidebar (Business Info)
|
| 138 |
st.sidebar.header("π Business Overview")
|
| 139 |
stats = get_aggregates()
|
| 140 |
st.sidebar.metric("Rows", stats["rows"])
|
|
@@ -148,7 +169,9 @@ database stats
|
|
| 148 |
create support ticket
|
| 149 |
""")
|
| 150 |
|
| 151 |
-
#
|
|
|
|
|
|
|
| 152 |
if "messages" not in st.session_state:
|
| 153 |
st.session_state.messages = []
|
| 154 |
|
|
|
|
| 16 |
# CONFIG
|
| 17 |
# ======================
|
| 18 |
HF_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
# ======================
|
| 21 |
# LOAD DATA (600+ rows)
|
|
|
|
| 45 |
df = load_data()
|
| 46 |
|
| 47 |
# ======================
|
| 48 |
+
# TOOLS
|
| 49 |
# ======================
|
| 50 |
def query_database(product: str):
|
| 51 |
logger.info("Tool call: query_database(%s)", product)
|
|
|
|
| 70 |
}
|
| 71 |
|
| 72 |
# ======================
|
| 73 |
+
# SAFETY
|
| 74 |
# ======================
|
| 75 |
def is_dangerous(text: str):
|
| 76 |
blocked = ["delete", "drop", "truncate", "remove table"]
|
| 77 |
return any(b in text.lower() for b in blocked)
|
| 78 |
|
| 79 |
# ======================
|
| 80 |
+
# STREAMLIT UI
|
| 81 |
+
# ======================
|
| 82 |
+
st.set_page_config(page_title="AI Procurement Agent (MVP)", layout="wide")
|
| 83 |
+
st.title("π€ AI Procurement Agent β MVP")
|
| 84 |
+
|
| 85 |
+
st.caption("Hugging Face powered, data-aware, safe AI agent demo")
|
| 86 |
+
|
| 87 |
+
# ======================
|
| 88 |
+
# TOKEN HANDLING (IMPORTANT FIX)
|
| 89 |
# ======================
|
| 90 |
+
hf_token = os.getenv("HF_TOKEN")
|
| 91 |
+
|
| 92 |
+
if not hf_token:
|
| 93 |
+
hf_token = st.sidebar.text_input(
|
| 94 |
+
"π Hugging Face API Token",
|
| 95 |
+
type="password",
|
| 96 |
+
help="Create a token at https://huggingface.co/settings/tokens"
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
if not hf_token:
|
| 100 |
+
st.warning("Please provide a Hugging Face API token to continue.")
|
| 101 |
+
st.stop()
|
| 102 |
|
| 103 |
+
client = InferenceClient(
|
| 104 |
+
model=HF_MODEL,
|
| 105 |
+
token=hf_token
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# ======================
|
| 109 |
+
# AGENT LOGIC
|
| 110 |
+
# ======================
|
| 111 |
def agent(user_input: str):
|
| 112 |
if is_dangerous(user_input):
|
| 113 |
logger.warning("Blocked unsafe operation")
|
|
|
|
| 147 |
Assistant:
|
| 148 |
"""
|
| 149 |
logger.info("HF model called")
|
| 150 |
+
return client.text_generation(
|
| 151 |
+
prompt,
|
| 152 |
+
max_new_tokens=200,
|
| 153 |
+
temperature=0.3
|
| 154 |
+
)
|
| 155 |
|
| 156 |
# ======================
|
| 157 |
+
# SIDEBAR β BUSINESS INFO
|
| 158 |
# ======================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
st.sidebar.header("π Business Overview")
|
| 160 |
stats = get_aggregates()
|
| 161 |
st.sidebar.metric("Rows", stats["rows"])
|
|
|
|
| 169 |
create support ticket
|
| 170 |
""")
|
| 171 |
|
| 172 |
+
# ======================
|
| 173 |
+
# CHAT
|
| 174 |
+
# ======================
|
| 175 |
if "messages" not in st.session_state:
|
| 176 |
st.session_state.messages = []
|
| 177 |
|