Jurabek commited on
Commit
9a541b2
Β·
verified Β·
1 Parent(s): 44c1501

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +40 -17
src/streamlit_app.py CHANGED
@@ -16,10 +16,6 @@ logger = logging.getLogger("HF-AI-Agent")
16
  # CONFIG
17
  # ======================
18
  HF_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
19
- HF_TOKEN = os.getenv("HF_TOKEN") # HF Spaces Secret
20
-
21
- if HF_TOKEN is None:
22
- raise RuntimeError("HF_TOKEN secret is missing")
23
 
24
  # ======================
25
  # LOAD DATA (600+ rows)
@@ -49,7 +45,7 @@ def load_data():
49
  df = load_data()
50
 
51
  # ======================
52
- # TOOLS (FUNCTION CALLS)
53
  # ======================
54
  def query_database(product: str):
55
  logger.info("Tool call: query_database(%s)", product)
@@ -74,17 +70,44 @@ def create_support_ticket(text: str):
74
  }
75
 
76
  # ======================
77
- # SAFETY GUARD
78
  # ======================
79
  def is_dangerous(text: str):
80
  blocked = ["delete", "drop", "truncate", "remove table"]
81
  return any(b in text.lower() for b in blocked)
82
 
83
  # ======================
84
- # AGENT LOGIC
 
 
 
 
 
 
 
 
85
  # ======================
86
- client = InferenceClient(model=HF_MODEL, token=HF_TOKEN)
 
 
 
 
 
 
 
 
 
 
 
87
 
 
 
 
 
 
 
 
 
88
  def agent(user_input: str):
89
  if is_dangerous(user_input):
90
  logger.warning("Blocked unsafe operation")
@@ -124,17 +147,15 @@ User: {user_input}
124
  Assistant:
125
  """
126
  logger.info("HF model called")
127
- return client.text_generation(prompt, max_new_tokens=200, temperature=0.3)
 
 
 
 
128
 
129
  # ======================
130
- # STREAMLIT UI
131
  # ======================
132
- st.set_page_config(page_title="AI Procurement Agent (MVP)", layout="wide")
133
- st.title("πŸ€– AI Procurement Agent β€” MVP (HF Spaces)")
134
-
135
- st.caption("Minimal, secure, data-aware AI agent demo")
136
-
137
- # Sidebar (Business Info)
138
  st.sidebar.header("πŸ“Š Business Overview")
139
  stats = get_aggregates()
140
  st.sidebar.metric("Rows", stats["rows"])
@@ -148,7 +169,9 @@ database stats
148
  create support ticket
149
  """)
150
 
151
- # Chat
 
 
152
  if "messages" not in st.session_state:
153
  st.session_state.messages = []
154
 
 
16
  # CONFIG
17
  # ======================
18
  HF_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
 
 
 
 
19
 
20
  # ======================
21
  # LOAD DATA (600+ rows)
 
45
  df = load_data()
46
 
47
  # ======================
48
+ # TOOLS
49
  # ======================
50
  def query_database(product: str):
51
  logger.info("Tool call: query_database(%s)", product)
 
70
  }
71
 
72
  # ======================
73
+ # SAFETY
74
  # ======================
75
  def is_dangerous(text: str):
76
  blocked = ["delete", "drop", "truncate", "remove table"]
77
  return any(b in text.lower() for b in blocked)
78
 
79
  # ======================
80
+ # STREAMLIT UI
81
+ # ======================
82
+ st.set_page_config(page_title="AI Procurement Agent (MVP)", layout="wide")
83
+ st.title("πŸ€– AI Procurement Agent β€” MVP")
84
+
85
+ st.caption("Hugging Face powered, data-aware, safe AI agent demo")
86
+
87
+ # ======================
88
+ # TOKEN HANDLING (IMPORTANT FIX)
89
  # ======================
90
+ hf_token = os.getenv("HF_TOKEN")
91
+
92
+ if not hf_token:
93
+ hf_token = st.sidebar.text_input(
94
+ "πŸ”‘ Hugging Face API Token",
95
+ type="password",
96
+ help="Create a token at https://huggingface.co/settings/tokens"
97
+ )
98
+
99
+ if not hf_token:
100
+ st.warning("Please provide a Hugging Face API token to continue.")
101
+ st.stop()
102
 
103
+ client = InferenceClient(
104
+ model=HF_MODEL,
105
+ token=hf_token
106
+ )
107
+
108
+ # ======================
109
+ # AGENT LOGIC
110
+ # ======================
111
  def agent(user_input: str):
112
  if is_dangerous(user_input):
113
  logger.warning("Blocked unsafe operation")
 
147
  Assistant:
148
  """
149
  logger.info("HF model called")
150
+ return client.text_generation(
151
+ prompt,
152
+ max_new_tokens=200,
153
+ temperature=0.3
154
+ )
155
 
156
  # ======================
157
+ # SIDEBAR – BUSINESS INFO
158
  # ======================
 
 
 
 
 
 
159
  st.sidebar.header("πŸ“Š Business Overview")
160
  stats = get_aggregates()
161
  st.sidebar.metric("Rows", stats["rows"])
 
169
  create support ticket
170
  """)
171
 
172
+ # ======================
173
+ # CHAT
174
+ # ======================
175
  if "messages" not in st.session_state:
176
  st.session_state.messages = []
177