# Hugging Face Spaces — status: Sleeping
import json
import os
import re
import sqlite3
from datetime import datetime, timedelta
from pathlib import Path

import chromadb
import gradio as gr
import pandas as pd
from faker import Faker
from groq import Groq
from sentence_transformers import SentenceTransformer, CrossEncoder
# ── Setup & Initialization ───────────────────────────────────────────────────
# The Groq key arrives as a Space secret; fail fast at import time without it.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if GROQ_API_KEY in (None, ""):
    raise ValueError("GROQ_API_KEY secret is missing!")
client = Groq(api_key=GROQ_API_KEY)  # shared client for every LLM call below
fake = Faker()                       # synthetic-data generator used by build_database()
MODEL = "llama-3.3-70b-versatile"    # Groq model id used for all completions
DB_PATH = "business.db"              # SQLite file (re)created at startup
# ── Data Generation (Runs on Startup) ────────────────────────────────────────
def build_database(db_path=None):
    """(Re)create the demo SQLite database and seed it with synthetic data.

    Drops and recreates the ``customers``, ``products`` and ``orders`` tables,
    then inserts 200 Faker-generated customers (every 7th one churned 5-60 days
    ago), 6 fixed products and 1,500 random orders dated within the last 180 days.

    Args:
        db_path: SQLite file to (over)write. Defaults to the module-level ``DB_PATH``.
    """
    if db_path is None:
        db_path = DB_PATH
    # One reference timestamp so churn dates and order dates share the same "today".
    now = datetime.now()
    conn = sqlite3.connect(db_path)
    try:
        cur = conn.cursor()
        cur.executescript("""
            DROP TABLE IF EXISTS orders;
            DROP TABLE IF EXISTS customers;
            DROP TABLE IF EXISTS products;
            CREATE TABLE customers (
                id INTEGER PRIMARY KEY,
                name TEXT,
                segment TEXT,
                region TEXT,
                churn_date TEXT
            );
            CREATE TABLE products (
                id INTEGER PRIMARY KEY,
                name TEXT,
                category TEXT,
                margin_pct REAL
            );
            CREATE TABLE orders (
                id INTEGER PRIMARY KEY,
                customer_id INTEGER REFERENCES customers(id),
                product_id INTEGER REFERENCES products(id),
                order_date TEXT,
                amount REAL,
                quantity INTEGER
            );
        """)

        # Seed customers: every 7th customer gets a churn date 5-60 days in the past.
        segments = ["Enterprise", "SMB", "Consumer"]
        regions = ["North", "South", "East", "West"]
        customers = []
        for i in range(1, 201):
            churn = None
            if i % 7 == 0:
                days_ago = fake.random_int(5, 60)
                churn = (now - timedelta(days=days_ago)).strftime("%Y-%m-%d")
            customers.append((i, fake.company(), fake.random_element(segments), fake.random_element(regions), churn))
        cur.executemany("INSERT INTO customers VALUES (?,?,?,?,?)", customers)

        # Seed products: margin_pct is stored as a fraction (0.72 == 72%).
        products = [
            (1, "Analytics Pro", "Software", 0.72),
            (2, "Data Connector", "Software", 0.65),
            (3, "Enterprise Suite", "Software", 0.80),
            (4, "Support Plan", "Service", 0.45),
            (5, "Onboarding Pack", "Service", 0.38),
            (6, "Hardware Hub", "Hardware", 0.22),
        ]
        cur.executemany("INSERT INTO products VALUES (?,?,?,?)", products)

        # Seed orders: random customer/product pairs dated within the last 180 days.
        orders = []
        for i in range(1, 1501):
            order_date = now - timedelta(days=fake.random_int(0, 180))
            orders.append((
                i,
                fake.random_int(1, 200),              # customer_id
                fake.random_int(1, 6),                # product_id
                order_date.strftime("%Y-%m-%d"),
                float(fake.random_int(200, 15000)),   # amount
                fake.random_int(1, 10),               # quantity
            ))
        cur.executemany("INSERT INTO orders VALUES (?,?,?,?,?,?)", orders)
        conn.commit()
    finally:
        # Release the file handle even if seeding raises part-way through.
        conn.close()
# Qualitative corpus indexed into the vector store for the RAG path.
# Each entry: stable id, display title, ISO date and the full text chunk.
DOCUMENTS = [
    {
        "id": "doc_001",
        "title": "Q2 Earnings Call Transcript",
        "date": "2024-07-15",
        "text": "CEO Opening: Revenue grew 18% YoY driven by Enterprise segment expansion. However, our SMB retention has been a concern — we saw elevated churn in the South region due to pricing sensitivity post our March price increase. CFO noted that hardware margins remain under pressure from supply chain costs and we are pivoting resources toward software-only offerings which carry 70%+ margins. The Analytics Pro product saw exceptional adoption.",
    },
    {
        "id": "doc_002",
        "title": "Q3 Earnings Call Transcript",
        "date": "2024-10-18",
        "text": "CEO: Q3 showed a sequential revenue dip of 6% compared to Q2. We attribute this to seasonal slowdown and our decision to exit lower-margin hardware deals. Gross margin actually improved from 58% to 63%. The Enterprise Suite saw record bookings. We are investing heavily in the Data Connector product to compete with emerging competitors. Churn in the Consumer segment increased — support tickets indicated frustration with onboarding complexity.",
    },
    {
        "id": "doc_003",
        "title": "Analyst Report: SaaS Sector Benchmarks",
        "date": "2024-09-01",
        "text": "Industry-wide, SaaS companies with ARR between $10M–$50M are seeing average gross margins of 68–74%. Net Revenue Retention (NRR) benchmarks sit at 110% for best-in-class and 95% for median performers. SMB-focused vendors consistently underperform Enterprise-focused peers on NRR by 12–18 points. Hardware-attached SaaS models are under margin pressure across the board. Analyst consensus: pure software plays with strong onboarding NPS outperform.",
    },
    {
        "id": "doc_004",
        "title": "Customer Support Tickets — Churn Analysis",
        "date": "2024-11-01",
        "text": "Analysis of 312 support tickets from churned customers (last 90 days): - 41% cited difficulty integrating the Data Connector with existing tools - 28% mentioned price increase in March as primary reason for leaving - 19% reported slow response times from support team - 12% said competitor offered better analytics dashboard UX. Most churned customers were in the SMB segment, South and West regions. Average time-to-churn post price increase: 47 days.",
    },
    {
        "id": "doc_005",
        "title": "CEO Internal Memo — 2025 Strategy",
        "date": "2024-12-01",
        "text": "To all leadership: Our 2025 focus is three-fold. First, improve SMB retention by launching a flexible pricing tier in Q1. Second, double down on Enterprise Suite — it carries 80% margins and customers expand usage by avg 35% YoY. Third, deprecate the Hardware Hub product line by Q3 2025 to free engineering resources. We expect this pivot to improve blended gross margin to 70%+ by year-end. The Data Connector roadmap will be shared at the all-hands in January.",
    },
    {
        "id": "doc_006",
        "title": "Board Presentation — Expansion Markets",
        "date": "2024-11-15",
        "text": "We evaluated three expansion geographies: APAC, LATAM, and EMEA. EMEA shows the strongest Enterprise density with 2.3x higher ACV than domestic Enterprise deals. LATAM SMB market is large but NRR data suggests high churn risk similar to our domestic SMB challenges. APAC requires significant localisation investment. Recommendation: prioritise EMEA Enterprise expansion in 2025 with a dedicated 4-person sales pod.",
    },
]
# Initialize and load resources (runs once at import time).
build_database()
embedder = SentenceTransformer("all-MiniLM-L6-v2")               # bi-encoder for first-stage retrieval
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")  # cross-encoder for reranking
chroma_client = chromadb.Client()  # In-memory DB is fine for HF spaces
# get_or_create + upsert keep startup idempotent: create_collection() raises if a
# "docs" collection already exists in this process (e.g. after a module reload).
collection = chroma_client.get_or_create_collection("docs")
texts = [d["text"].strip() for d in DOCUMENTS]
collection.upsert(
    ids=[d["id"] for d in DOCUMENTS],
    embeddings=embedder.encode(texts).tolist(),
    documents=texts,
    metadatas=[{"title": d["title"], "date": d["date"]} for d in DOCUMENTS],
)
# Schema summary injected into the router and SQL-generation prompts. Value formats
# match what build_database() writes: dates are 'YYYY-MM-DD' TEXT and margin_pct is
# a fraction, so the LLM must not compare it against whole percentages like 50.
DB_SCHEMA = """
Tables available in the SQL database:
customers(id, name, segment[Enterprise|SMB|Consumer], region[North|South|East|West], churn_date[TEXT 'YYYY-MM-DD', NULL=active])
products(id, name, category[Software|Service|Hardware], margin_pct[REAL fraction 0-1, e.g. 0.72 = 72%])
orders(id, customer_id->customers.id, product_id->products.id, order_date[TEXT 'YYYY-MM-DD'], amount[REAL], quantity[INTEGER])
The database contains: order records, customer profiles, product margins, churn dates.
The database does NOT contain: strategic plans, executive quotes, analyst opinions,
support ticket details, reasons for churn, or qualitative sentiment.
"""
# ── Core Logic Functions ─────────────────────────────────────────────────────
_VALID_ROUTES = ("SQL", "RAG", "HYBRID")


def route_query(user_query: str) -> dict:
    """Classify a question as needing SQL, RAG, or both (HYBRID) via the LLM.

    Args:
        user_query: The user's natural-language question.

    Returns:
        A dict that always has the keys ``route`` (one of "SQL", "RAG",
        "HYBRID"), ``confidence``, ``reasoning``, ``sql_needed``,
        ``rag_needed``, ``sql_description`` and ``rag_description``. If the
        model's reply cannot be parsed as a JSON object, a low-confidence
        HYBRID routing is returned so both paths run instead of crashing.
    """
    prompt = f"""You are a query router in a hybrid SQL+RAG analytics system.
{DB_SCHEMA}
The RAG document store contains: earnings call transcripts, analyst reports,
CEO memos, support ticket analysis, and board presentations.
Classify this query and return ONLY valid JSON — no extra text, no markdown fences.
Query: "{user_query}"
Return this exact JSON structure:
{{
  "route": "SQL" | "RAG" | "HYBRID",
  "confidence": "high" | "medium" | "low",
  "reasoning": "one sentence",
  "sql_needed": true | false,
  "rag_needed": true | false,
  "sql_description": "what to query from DB, or null",
  "rag_description": "what docs to retrieve, or null"
}}"""
    response = client.chat.completions.create(model=MODEL, messages=[{"role": "user", "content": prompt}], max_tokens=400, temperature=0)
    raw = (response.choices[0].message.content or "").strip()
    # Models sometimes add fences or prose despite instructions: parse only the
    # outermost {...} span.
    start, end = raw.find("{"), raw.rfind("}")
    try:
        parsed = json.loads(raw[start:end + 1]) if 0 <= start < end else None
    except json.JSONDecodeError:
        parsed = None
    if not isinstance(parsed, dict):
        parsed = {"route": "HYBRID", "confidence": "low",
                  "reasoning": "Router output was not valid JSON; defaulting to HYBRID."}
    # Normalise so callers can index every key without KeyError.
    route = str(parsed.get("route", "HYBRID")).strip().upper()
    if route not in _VALID_ROUTES:
        route = "HYBRID"
    parsed["route"] = route
    parsed.setdefault("confidence", "low")
    parsed.setdefault("reasoning", "")
    parsed.setdefault("sql_needed", route in ("SQL", "HYBRID"))
    parsed.setdefault("rag_needed", route in ("RAG", "HYBRID"))
    parsed.setdefault("sql_description", None)
    parsed.setdefault("rag_description", None)
    return parsed
def generate_sql(user_query: str, routing_info: dict) -> str:
    """Ask the LLM for a single SQLite SELECT statement answering ``user_query``.

    Args:
        user_query: The user's natural-language question.
        routing_info: Router output. Its ``sql_description`` (which may be null)
            is passed to the model as a hint.

    Returns:
        The SQL text with any markdown code fences removed. The query is not
        validated here; execute_sql applies the SELECT-only check.
    """
    # The router emits JSON null when it has no hint; don't render that as "None".
    hint = routing_info.get("sql_description") or ""
    prompt = f"""You are a SQLite expert. Write a single SELECT query to answer the question.
{DB_SCHEMA}
Rules:
- Write ONLY a SELECT statement
- Use SQLite syntax
- Limit results to 20 rows unless asking for a total/count
- Today's date: {datetime.now().strftime('%Y-%m-%d')}
Question: {user_query}
Routing hint: {hint}
Return ONLY the SQL query, nothing else."""
    response = client.chat.completions.create(model=MODEL, messages=[{"role": "user", "content": prompt}], max_tokens=500, temperature=0)
    sql = response.choices[0].message.content or ""
    # Strip ```sql / ```SQL / ```sqlite / ``` fences regardless of case.
    return re.sub(r"```(?:sqlite|sql)?", "", sql, flags=re.IGNORECASE).strip()
# Statement keywords that must not appear as whole words in a generated query.
# Whole-word matching avoids false positives on names like "updated_at".
_WRITE_KEYWORDS = re.compile(
    r"\b(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|ATTACH|DETACH|PRAGMA|VACUUM|REINDEX)\b"
)


def execute_sql(sql: str, db_path=None) -> dict:
    """Run an LLM-generated read-only query and return its rows plus a markdown table.

    Args:
        sql: Query text produced by the LLM (treated as untrusted).
        db_path: SQLite file to query. Defaults to the module-level ``DB_PATH``.

    Returns:
        On success: ``{"sql", "rows", "markdown", "error": None}``, where ``rows``
        is a list of record dicts. On failure or rejection:
        ``{"error", "sql", "rows": None}``.
    """
    clean = sql.strip().upper()
    # Accept plain SELECTs and CTEs; anything else is rejected before touching the DB.
    if not clean.startswith(("SELECT", "WITH")) or _WRITE_KEYWORDS.search(clean):
        return {"error": "Query blocked — only SELECT allowed.", "sql": sql, "rows": None}
    if db_path is None:
        db_path = DB_PATH
    try:
        # A read-only connection is the real safeguard: even a query that slips past
        # the keyword filter cannot modify the database.
        conn = sqlite3.connect(f"{Path(db_path).resolve().as_uri()}?mode=ro", uri=True)
        try:
            df = pd.read_sql(sql, conn)
        finally:
            conn.close()
        if df.empty:
            markdown_table = "(no rows returned)"
        else:
            try:
                markdown_table = df.to_markdown(index=False)
            except ImportError:
                # DataFrame.to_markdown needs the optional "tabulate" package.
                markdown_table = df.to_string(index=False)
        return {"sql": sql, "rows": df.to_dict(orient="records"), "markdown": markdown_table, "error": None}
    except Exception as e:
        # Errors are returned (not raised) so sql_path can ask the LLM for a fix.
        return {"error": str(e), "sql": sql, "rows": None}
def sql_path(user_query: str, routing_info: dict) -> dict:
    """Generate SQL for the question, run it, and retry once with an LLM repair.

    Args:
        user_query: The user's natural-language question.
        routing_info: Router output, forwarded to generate_sql.

    Returns:
        The execute_sql result dict for the last attempt (the repaired query if
        the first one failed).
    """
    sql = generate_sql(user_query, routing_info)
    result = execute_sql(sql)
    if not result["error"]:
        return result
    # Include the original question so the repair targets the intended answer and
    # not just the syntax, and restate the SELECT-only rule execute_sql enforces.
    fix_prompt = (
        "The SQLite query below caused an error. Rewrite it as a single SELECT statement.\n"
        f"Question: {user_query}\n"
        f"Original query: {sql}\n"
        f"Error: {result['error']}\n"
        f"{DB_SCHEMA}\n"
        "Return ONLY the corrected SQL query."
    )
    resp = client.chat.completions.create(model=MODEL, messages=[{"role": "user", "content": fix_prompt}], max_tokens=400, temperature=0)
    fixed = resp.choices[0].message.content or ""
    fixed = re.sub(r"```(?:sqlite|sql)?", "", fixed, flags=re.IGNORECASE).strip()
    return execute_sql(fixed)
def retrieve_chunks(user_query: str, top_k_retrieve: int = 6, top_k_final: int = 3) -> list:
    """Two-stage retrieval: vector search in Chroma, then cross-encoder reranking.

    Args:
        user_query: Natural-language question to search for.
        top_k_retrieve: Candidates fetched from the vector store (clamped to the
            number of stored documents).
        top_k_final: Number of reranked chunks to return.

    Returns:
        Up to ``top_k_final`` dicts with keys ``id``, ``text``, ``title``,
        ``date``, ``vec_score`` and ``rerank_score``, sorted by ``rerank_score``
        in descending order. Returns an empty list if nothing is indexed.
    """
    # Requesting more results than are stored raises in some Chroma versions.
    n_results = min(top_k_retrieve, collection.count())
    if n_results <= 0:
        return []
    query_embedding = embedder.encode([user_query]).tolist()
    results = collection.query(query_embeddings=query_embedding, n_results=n_results)
    # Chroma returns one list per query embedding; [0] selects our single query.
    ids = results["ids"][0]
    docs = results["documents"][0]
    metas = results["metadatas"][0]
    dists = results["distances"][0]
    candidates = [
        {
            "id": doc_id,
            "text": text,
            "title": meta["title"],
            "date": meta["date"],
            # The collection uses Chroma's default metric (squared L2), so this is
            # a proximity proxy (higher = closer), not a cosine similarity.
            "vec_score": 1 - dist,
        }
        for doc_id, text, meta, dist in zip(ids, docs, metas, dists)
    ]
    if not candidates:
        return []
    rerank_scores = reranker.predict([(user_query, c["text"]) for c in candidates]).tolist()
    for c, score in zip(candidates, rerank_scores):
        c["rerank_score"] = round(score, 4)
    return sorted(candidates, key=lambda c: c["rerank_score"], reverse=True)[:top_k_final]
def rag_path(user_query: str) -> dict:
    """Fetch the top reranked chunks and render them as a numbered, citable context.

    Returns a dict with the raw ``chunks`` list and a ``context_block`` string in
    which each chunk is prefixed by ``[Source N: title (date)]``.
    """
    chunks = retrieve_chunks(user_query)
    sections = []
    for number, chunk in enumerate(chunks, start=1):
        header = f"[Source {number}: {chunk['title']} ({chunk['date']})]"
        sections.append(header + "\n" + chunk["text"].strip())
    return {"chunks": chunks, "context_block": "\n\n".join(sections)}
def synthesise(user_query, route, sql_result, rag_result):
    """Compose the final answer from SQL results and/or retrieved documents.

    Args:
        user_query: The user's question.
        route: Router decision ("SQL", "RAG" or "HYBRID"). It is reported to
            the model when no context could be gathered.
        sql_result: Dict from sql_path, or None if the SQL path was skipped.
        rag_result: Dict from rag_path, or None if the RAG path was skipped.

    Returns:
        The model's answer as markdown text.
    """
    context_parts = []
    if sql_result and sql_result.get("rows") is not None:
        context_parts.append(f"## SQL Database Results\nQuery executed: `{sql_result['sql']}`\n\n{sql_result['markdown']}")
    elif sql_result and sql_result.get("error"):
        context_parts.append(f"## SQL Database Results\nCould not retrieve SQL data: {sql_result['error']}")
    if rag_result and rag_result.get("context_block"):
        context_parts.append(f"## Retrieved Documents\n{rag_result['context_block']}")
    # An empty context leaves the model nothing to ground on. Say so explicitly so
    # the "insufficient context" rule applies instead of an unsupported answer.
    context = "\n\n---\n\n".join(context_parts) or f"(No context was retrieved for route {route}.)"
    system_prompt = """You are a business intelligence assistant. Answer using ONLY the provided context.
Rules:
- Cite SQL data as "internal database"
- Cite documents by their source title in brackets
- If SQL numbers and documents contradict, surface both
- If context is insufficient, say so clearly
- Be concise and structured"""
    response = client.chat.completions.create(
        model=MODEL,
        messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": f"User question: {user_query}\n\nContext:\n{context}"}],
        max_tokens=800, temperature=0.2
    )
    return (response.choices[0].message.content or "").strip()
# ── UI Integration ───────────────────────────────────────────────────────────
def process_query(user_query: str):
    """Run the route → retrieve → synthesise pipeline for one UI request.

    Args:
        user_query: Text from the question box.

    Returns:
        ``(answer_markdown, trace_markdown)`` for the two Gradio outputs.
    """
    # Don't spend LLM calls on a blank submission (e.g. Enter on an empty box).
    if not user_query or not user_query.strip():
        return "Please enter a question.", ""
    trace_log = []
    try:
        routing = route_query(user_query)
    except ValueError:
        # json.JSONDecodeError subclasses ValueError: fall back to running both paths.
        routing = {"route": "HYBRID", "confidence": "low",
                   "reasoning": "Router output could not be parsed; running both SQL and RAG."}
    route = routing.get("route", "HYBRID")
    trace_log.append(f"**Route Selected:** {route} (Confidence: {routing.get('confidence', 'unknown')})")
    trace_log.append(f"**Reasoning:** {routing.get('reasoning', '')}")
    # If the router omitted the explicit flags, derive them from the route label.
    sql_needed = routing.get("sql_needed", route in ("SQL", "HYBRID"))
    rag_needed = routing.get("rag_needed", route in ("RAG", "HYBRID"))
    sql_result, rag_result = None, None
    if sql_needed:
        sql_result = sql_path(user_query, routing)
        if sql_result.get("rows") is not None:
            # Show the query that actually ran (after any automatic repair).
            trace_log.append(f"**SQL Executed:**\n```sql\n{sql_result['sql']}\n```")
        else:
            trace_log.append(f"**SQL Error:** {sql_result['error']}")
    if rag_needed:
        rag_result = rag_path(user_query)
        trace_log.append("**Documents Retrieved:**")
        if not rag_result["chunks"]:
            trace_log.append("- (no documents found)")
        for c in rag_result["chunks"]:
            trace_log.append(f"- [{c['rerank_score']:+.3f}] {c['title']}")
    answer = synthesise(user_query, route, sql_result, rag_result)
    return answer, "\n\n".join(trace_log)
# Build Gradio UI: a question box, the synthesised answer, and a collapsible trace.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("# 📊 Text-to-SQL + RAG Hybrid Business Assistant")
    gr.Markdown("Ask questions about internal metrics (database) or strategy/qualitative notes (documents).")
    with gr.Row():
        query_input = gr.Textbox(label="Your Question", placeholder="e.g. Which customers churned this quarter and why?")
        btn = gr.Button("Ask Assistant", variant="primary")
    with gr.Row():
        with gr.Column(scale=2):
            answer_output = gr.Markdown(label="Answer")
        with gr.Column(scale=1):
            with gr.Accordion("Agent Trace Log (Under the Hood)", open=False):
                trace_output = gr.Markdown()
    # The button and pressing Enter in the textbox both run the same pipeline.
    btn.click(fn=process_query, inputs=query_input, outputs=[answer_output, trace_output])
    query_input.submit(fn=process_query, inputs=query_input, outputs=[answer_output, trace_output])
    gr.Examples([
        "What is our total revenue by product category this month?",
        "What did the CEO say about the 2025 strategy and product priorities?",
        "Which customers churned this quarter and why are they leaving?",
        "How do our product margins compare to industry benchmarks?"
    ], inputs=query_input)

if __name__ == "__main__":
    app.launch()