# NOTE: Hugging Face Space page residue removed ("0ndr3's picture / Update app.py / e1cc806 verified") — not part of the program.
# IMPORTS
import warnings
warnings.filterwarnings("ignore", message="Failed to load HostKeys")
warnings.filterwarnings("ignore", message="The 'tuples' format for chatbot messages is deprecated")
warnings.filterwarnings("ignore", category=DeprecationWarning)
import os, re, json, pandas as pd, pysftp
from dateutil import parser
import gradio as gr
from langchain_core.documents import Document
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableMap, RunnableLambda
from langchain.memory import ConversationBufferMemory
from langchain_groq import ChatGroq
# SECRETS & PATHS
SFTP_HOST = os.getenv("SFTP_HOST")            # SFTP server hostname (env secret)
SFTP_USER = os.getenv("SFTP_USER")            # SFTP username (env secret)
SFTP_PASSWORD = os.getenv("SFTP_PASSWORD")    # SFTP password (env secret)
SFTP_ALERTS_DIR = "/home/birkbeck/alerts"     # remote directory holding alert CSV files
GROQ_API_KEY = os.getenv("GROQ_API_KEY")      # Groq API key for the LLM calls (env secret)
HISTORICAL_JSON = "data/big_prize_data.json"  # local historical prize dataset
# CHAT MEMORY
# Full-conversation buffer; input_key="question" tells it which invoke key is the user turn.
memory = ConversationBufferMemory(memory_key="chat_history", input_key="question")
# BUILD HISTORICAL CHROMA DB
def build_chroma_db():
    """Index the historical prize records in an in-memory Chroma vector store.

    Reads HISTORICAL_JSON (a list of per-raffle dicts), renders each record
    as a short text snippet for embedding, and carries the numeric fields in
    metadata so downstream filters can use them.

    Returns:
        Chroma: a vector store built over every historical record.
    """
    # Explicit encoding so "£" in the data decodes the same everywhere.
    with open(HISTORICAL_JSON, encoding="utf-8") as f:
        raw = json.load(f)
    docs = []
    for d in raw:
        # Join the fields with explicit separators/newlines — the original
        # adjacent f-strings concatenated with no separator, producing run-on
        # text like "...£500Single ticket price...". Also restores "£" from
        # the mojibake "Β£".
        content = (
            f"Raffle {d['raffle']} | Prize: {d['prize']} | Value: £{d['value']}\n"
            f"Single ticket price: £{d['ticket_price']}\n"
            f"Won at tickets sold:\n"
            f"qty: {d['approx_tickets_sold']}\n"
            f"pct: {d['percent_path']}%\n"
            f"val: £{d['approx_value_at_win']}"
        )
        # Numeric coercions: the JSON may store these as strings.
        md = {
            "raffle": d["raffle"],
            "prize": d["prize"],
            "value": float(d["value"]),
            "tickets_sold": int(d["approx_tickets_sold"]),
            "percent_path": float(d["percent_path"]),
            "ticket_price": float(d["ticket_price"]),
            "approx_value_at_win": float(d["approx_value_at_win"]),
            # Historical records may lack a timestamp; sort them before any live doc.
            "timestamp": d.get("timestamp", "2000-01-01T00:00:00"),
            "source": "historical"
        }
        docs.append(Document(page_content=content, metadata=md))
    emb = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    return Chroma.from_documents(docs, emb)
# LOAD RECENT/LIVE ALERTS VIA SFTP
def load_live_alerts():
    """Fetch recent/live raffle-win alert CSVs over SFTP and wrap them as Documents.

    Connects with SFTP_HOST/SFTP_USER/SFTP_PASSWORD, scans SFTP_ALERTS_DIR for
    ``*.csv`` files, and builds one Document per usable file. Files with fewer
    than two rows are skipped.

    Returns:
        list[Document]: one per alert, metadata tagged source="recent and live".
    """
    cnopts = pysftp.CnOpts()
    # NOTE(review): disabling host-key verification allows MITM attacks —
    # consider pinning the server key instead of hostkeys=None.
    cnopts.hostkeys = None
    alerts = []
    with pysftp.Connection(
        host=SFTP_HOST, username=SFTP_USER, password=SFTP_PASSWORD, cnopts=cnopts
    ) as sftp:
        sftp.chdir(SFTP_ALERTS_DIR)
        for fn in sorted(sftp.listdir()):
            if not fn.endswith(".csv"):
                continue
            # NOTE(review): the sftp.open() handle is never explicitly closed;
            # it is released only at connection teardown — consider a with-block.
            df = pd.read_csv(sftp.open(fn))
            if len(df) < 2:
                # Need at least two rows: row 0 holds the win data, row 1's
                # "timestamp" cell carries the percent figure parsed below.
                continue
            row = df.iloc[0]
            ticket_price = float(row["ticket_price"])
            tickets_sold = int(row["entries_sold"])
            # Percent-sold at win time. Assumes row 1's "timestamp" cell embeds a
            # decimal like "12.34" — TODO confirm against the alert CSV format;
            # raises AttributeError if the pattern is missing.
            pct = float(re.search(r"(\d+\.\d+)", str(df.iloc[1].get("timestamp",""))).group(1))
            ts_iso = row["timestamp"]
            # Win date rendered UK-style (dd/mm/yyyy) from the ISO timestamp.
            date = parser.isoparse(ts_iso).strftime("%d/%m/%Y")
            # Raffle number extracted from a name like "... #123".
            num = int(re.search(r"#(\d+)", row["raffle_name"]).group(1))
            # Strip currency symbol and thousands separators before float().
            # NOTE(review): "Β£" looks like a mojibake'd "£" — confirm the CSV's
            # actual byte content before changing it.
            raw_val = str(row["prize_value"]).replace("Β£","").replace(",","")
            val = float(raw_val)
            # Revenue accumulated by the time the prize was won.
            approx_value_at_win = tickets_sold * ticket_price
            # NOTE(review): adjacent f-strings concatenate with no separator,
            # so this snippet runs together; same issue exists elsewhere.
            content = (
                f"Raffle {num} | Prize: {row['prize_name']} | Value: Β£{val:.2f}"
                f"Single ticket price: Β£{ticket_price:.2f} | Won @ {date}"
                f"Won at tickets sold:"
                f"qty: {tickets_sold}"
                f"pct: {pct:.2f}%"
                f"val: Β£{approx_value_at_win:.2f}"
            )
            md = {
                "raffle": row["raffle_name"],
                "raffle_number": num,
                "prize": row["prize_name"],
                "value": val,
                "tickets_sold": tickets_sold,
                "percent_path": pct,
                "ticket_price": ticket_price,
                "approx_value_at_win": approx_value_at_win,
                "timestamp": ts_iso,
                "source": "recent and live"
            }
            alerts.append(Document(page_content=content, metadata=md))
    return alerts
# RETRIEVER
# Build the historical index and fetch live alerts once, at import/startup time.
db = build_chroma_db()
live_docs = load_live_alerts()
def combined_docs(q: str):
    """Return the top-8 historical matches for *q* followed by all live alerts."""
    historical_hits = db.similarity_search(q, k=8)
    return [*historical_hits, *live_docs]
# PROMPT + FILTER CHAIN
# Final answer prompt: persona + reasoning rules, then truncated history,
# retrieved raffle snippets, and the (rewritten) user question.
# NOTE(review): "β€œ/" sequences in the template look like mojibake'd curly
# quotes — confirm file encoding; left byte-identical here.
prompt = PromptTemplate(
    input_variables=["chat_history","context","question"],
    template="""
You are **Rafael The Raffler**, a calm friendly expert in instant-win raffle analysis.
**Only** describe your strengths (raffle timing, value insights, patterns) when the user explicitly asks β€œwhat do you do?” or "what you good at?".
If they merely greet you or ask anything else, do **not** list your strengthsβ€”just answer the question.
Reasoning Rules:
1. **Interpreting β€œWhen”:** Whenever the user asks β€œWhen…?”, interpret that as β€œAt what tickets-sold count and percent did the prize win occur?” Do *not* give calendar dates or times.
--- Conversation So Far ---
{chat_history}
--- Raffle Data ---
{context}
--- Question ---
{question}
"""
)
def filter_docs(inputs):
    """Post-retrieval filter applied before prompting.

    Args:
        inputs: {"documents": list of Documents, "question": str}.

    Returns:
        dict: {"documents": filtered docs, "question": lower-cased question}.

    Two special cases, checked in order:
      * "live"/"latest"/"recent" + prize/raffle/won -> keep only the newest
        live-sourced document (by ISO timestamp).
      * "above/over/greater than £N"                -> keep docs whose value
        strictly exceeds N.
    Otherwise documents pass through unchanged.
    """
    docs, q = inputs["documents"], inputs["question"].lower()
    # RECENT/LIVE: answer "latest prize" style questions from the newest alert only.
    if ("live" in q or "latest" in q or "recent" in q) and any(w in q for w in ("prize","raffle","won")):
        live = [d for d in docs if d.metadata["source"]=="recent and live"]
        if live:
            recent = max(live, key=lambda d: parser.isoparse(d.metadata["timestamp"]))
            return {"documents":[recent], "question":q}
    # THRESHOLD: e.g. "above £1,500" or "over 2000.50".
    # Fixes the mojibake "Β£" in the original pattern (a real "£" never matched)
    # and generalizes the amount to accept an optional decimal part.
    m = re.search(r"(?:above|over|greater than)\s*£?([\d,]+(?:\.\d+)?)", q)
    if m:
        thr = float(m.group(1).replace(",",""))
        docs = [d for d in docs if d.metadata["value"] > thr]
    return {"documents":docs, "question":q}
# FOLLOW-UP QUESTION REWRITING
# Prompt used to turn a follow-up ("what about the second one?") into a
# standalone question, given the conversation so far.
question_rewrite_template = PromptTemplate(
    input_variables=["chat_history","question"],
    template="""
Rewrite the following user query to be a fully self-contained question, given the conversation so far.
Conversation:
{chat_history}
Follow-up:
{question}
Rewritten standalone question:"""
)
# Chain: raw question -> {history, question} -> formatted rewrite prompt ->
# Groq LLM -> plain-string standalone question.
rewrite_chain = (
    RunnableLambda(lambda q: {
        # Pull the accumulated chat history from the shared memory buffer.
        "chat_history": memory.load_memory_variables({})["chat_history"],
        "question": q
    })
    | RunnableLambda(lambda inp: question_rewrite_template.format(**inp))
    | ChatGroq(api_key=GROQ_API_KEY, model="llama3-8b-8192")
    | StrOutputParser()
)
# RAG + CHATGROQ CHAIN (WITH REWRITE) ────────────────────────────────────
# End-to-end RAG chain: rewrite -> retrieve -> filter -> prompt -> LLM -> str.
retrieval_chain = (
    # 1. REWRITE QUESTION FIRST (make follow-ups self-contained)
    rewrite_chain
    # 2. RETRIEVE DOCS AGAINST REWRITTEN QUESTION (historical top-8 + all live)
    | RunnableMap({
        "documents": lambda rewritten_q: combined_docs(rewritten_q),
        "question": lambda rewritten_q: rewritten_q
    })
    | RunnableLambda(filter_docs)
    # 3. BUILD FINAL INPUTS AND TRUNCATE HISTORY
    | RunnableLambda(lambda d: {
        # Keep only the last 4 history lines to bound the prompt size.
        "chat_history": "\n".join(
            memory.load_memory_variables({})["chat_history"].splitlines()[-4:]
        ),
        "context": "\n".join(doc.page_content for doc in d["documents"]),
        "question": d["question"]
    })
    # 4. FORMAT FINAL PROMPT AND CALL LLM
    | RunnableLambda(lambda inp: prompt.format(**inp))
    | ChatGroq(api_key=GROQ_API_KEY, model="llama3-8b-8192")
    | StrOutputParser()
)
# GRADIO
# Markdown description shown in the Gradio UI.
# NOTE(review): "πŸ‘‹" below looks like a mojibake'd 👋 emoji — confirm the
# file's encoding; left byte-identical here.
WELCOME = """
πŸ‘‹ **Welcome to Rafael The Raffler**
Your raffle-analysis assistant with RAG.
Ask about raffle wins, ticket timing, prize values or the latest live raffle.
"""
# GREETING HANDLING
def handle_greeting(question: str):
    """Return a canned reply when the input is nothing but a greeting.

    Accepts "hi", "hello" or "hey" (any case), optionally followed by
    trailing '.', '!' or '?' characters. Returns None for everything else.
    """
    core = question.strip().rstrip(".!?")
    if core.lower() in ("hi", "hello", "hey"):
        return "Hello! How can I help you with your raffle analysis today?"
    return None
def gradio_chat(question: str) -> str:
    """Gradio handler: answer one user message and record it in chat memory.

    Pure greetings are short-circuited with a canned reply; anything else is
    routed through the rewrite + RAG chain.
    """
    # 1. GREETING ONLY?
    greet = handle_greeting(question)
    if greet:
        # SAVE GREETING so follow-ups still see a coherent history.
        memory.save_context({"question": question}, {"answer": greet})
        return greet
    # 2. OTHERWISE > RAG CHAIN
    answer = retrieval_chain.invoke(question)
    memory.save_context({"question": question}, {"answer": answer})
    return answer
# Single-turn Gradio UI; conversation state lives server-side in `memory`.
demo = gr.Interface(
    fn=gradio_chat,
    inputs=gr.Textbox(lines=1, placeholder="e.g. When was the recent big prize won?"),
    outputs="text",
    title="Rafael The Raffler",
    description=WELCOME,
)
if __name__ == "__main__":
    # Bind to all interfaces on 7860 (the standard Hugging Face Spaces port).
    demo.launch(server_name="0.0.0.0", server_port=7860)