# wiserkhan / app.py — Hugging Face Space (commit ac86bff)
# Wiser AI assistant: Gradio chat UI backed by Zephyr-7B and TimescaleDB.
import gradio as gr
from huggingface_hub import InferenceClient
import psycopg2
import os
import re
# Hosted Zephyr-7B chat model via the Hugging Face Inference API.
# NOTE(review): no HF token is passed, so this relies on anonymous /
# Space-level access — confirm rate limits are acceptable.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
# TimescaleDB connection settings, injected via Hugging Face Space secrets.
# ``port`` is normalized to int: os.getenv returns a *string* when the env
# var is set, which previously made the value's type depend on deployment.
DB_CONFIG = {
    "host": os.getenv("DB_HOST"),
    "port": int(os.getenv("DB_PORT", "5432")),
    "database": os.getenv("DB_NAME"),
    "user": os.getenv("DB_USER"),
    "password": os.getenv("DB_PASSWORD"),
}
# Query TimescaleDB
def query_timescaledb(sql_query):
    """Execute ``sql_query`` and return all result rows.

    Returns a list of row tuples on success, or a ``"DB Error: ..."``
    string on any failure — callers (see ``respond``) check for ``str``
    to detect the error case, so that contract is preserved.
    """
    conn = None
    try:
        conn = psycopg2.connect(**DB_CONFIG)
        # FIX: psycopg2's ``with connection`` only wraps the transaction
        # (commit/rollback) — it does NOT close the connection, so the
        # original leaked one connection per query. Close it in finally.
        with conn:
            with conn.cursor() as cur:
                cur.execute(sql_query)
                return cur.fetchall()
    except Exception as e:
        # Broad catch is deliberate: this is the DB boundary and errors
        # are surfaced to the chat as text rather than crashing the app.
        return f"DB Error: {e}"
    finally:
        if conn is not None:
            conn.close()
# Basic pattern matching to route question types
def get_sql_for_question(message):
    """Map a user question to a ``(sql, context_prefix)`` pair.

    Matching is case-insensitive substring search over known trigger
    phrases. Returns ``(None, None)`` when no metric keyword is found.
    """
    text = message.lower()
    # Each route: (trigger phrases, SQL to run, prefix shown with the result).
    routes = (
        (
            ("average current",),
            """
        SELECT AVG(CT_Avg) as avg_current FROM machine_current_log
        WHERE created_at >= NOW() - INTERVAL '1 day';
        """,
            "Here's the average current over the past 24 hours:",
        ),
        (
            ("total current",),
            """
        SELECT created_at, total_current FROM machine_current_log
        WHERE created_at >= NOW() - INTERVAL '1 day'
        ORDER BY created_at DESC LIMIT 10;
        """,
            "Here are the latest 10 total current readings:",
        ),
        (
            ("state duration", "longest running state"),
            """
        SELECT state, MAX(state_duration) FROM machine_current_log
        WHERE created_at >= NOW() - INTERVAL '1 week'
        GROUP BY state
        ORDER BY MAX(state_duration) DESC LIMIT 1;
        """,
            "Here's the longest running machine state this week:",
        ),
        (
            ("fault",),
            """
        SELECT fault_status, COUNT(*) FROM machine_current_log
        WHERE fault_status IS NOT NULL
        GROUP BY fault_status
        ORDER BY COUNT(*) DESC;
        """,
            "Here is the frequency of different fault statuses:",
        ),
    )
    for triggers, sql, prefix in routes:
        if any(phrase in text for phrase in triggers):
            return sql, prefix
    return None, None
# Respond using LLM + data if relevant
def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p):
    """Stream an assistant reply, optionally grounded in TimescaleDB data.

    If the question matches a known metric, the matching SQL is run and
    the result is injected into the prompt. Yields the accumulated
    response string after every streamed token (Gradio streaming format).
    """
    sql_query, context_prefix = get_sql_for_question(message)
    if sql_query:
        result = query_timescaledb(sql_query)
        if isinstance(result, str):  # query_timescaledb signals errors as str
            db_info = result
        elif not result:
            db_info = "No data found."
        else:
            # Clean and format result
            db_info = "\n".join(str(row) for row in result)
        message = f"{context_prefix}\n{db_info}\n\nAnswer the user's query based on this information."

    messages = [{"role": "system", "content": system_message}]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})

    response = ""
    # FIX: loop variable no longer shadows ``message``; the final stream
    # chunk's ``delta.content`` can be None, which previously raised
    # ``TypeError: can only concatenate str`` — coalesce to "".
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content or ""
        response += token
        yield response
# Gradio UI: chat interface with tunable generation parameters.
with gr.Blocks() as demo:
    gr.Markdown("## 🤖 Wiser AI Assistant")
    gr.Markdown(
        """
        Welcome to **Wiser's AI Assistant**, your smart companion for all things manufacturing.
        Ask anything like:
        - "What's the average current today?"
        - "What faults happened this week?"
        - "Tell me the latest machine states"
        - "Any machines running too long?"
        I'm connected to live TimescaleDB data 👨‍🏭📊
        """
    )
    # Extra inputs are passed through to respond() in this order.
    system_box = gr.Textbox(
        value="You are Wiser, an expert AI assistant in smart manufacturing. Help users understand machine metrics using the latest database values.",
        label="System message"
    )
    tokens_slider = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
    temperature_slider = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
    top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
    gr.ChatInterface(
        respond,
        additional_inputs=[system_box, tokens_slider, temperature_slider, top_p_slider],
    )
# Launch the Gradio app only when executed directly (not on import).
if __name__ == "__main__":
    demo.launch()