Spaces:
Sleeping
Sleeping
| """ | |
| tabs/shared_ai.py β Version 3.0 | |
| AI Decision Assistant for Churn Simulation | |
| ------------------------------------------ | |
| β’ Works both locally (Ollama) and on Hugging Face (Hub Inference) | |
| β’ Uses validated JSON schema (Pydantic) | |
| β’ Adds context about available features | |
| β’ Calls dynamic simulation engine (simulate_plan) | |
| β’ Provides user-friendly error feedback | |
| """ | |
import os
import json
from typing import List, Literal, Optional

import gradio as gr
from pydantic import BaseModel, ValidationError

# LangChain components.
# ChatHuggingFace/HuggingFaceEndpoint come from langchain_huggingface (the
# langchain_community copy is deprecated and was shadowed anyway);
# ChatOllama still lives in langchain_community.
from langchain_community.chat_models import ChatOllama
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

# Local utilities
from utils.scenario_engine_ng import simulate_plan
from utils.history import log_simulation
# ---------------------------------------------------------------------
# 🧩 Pydantic Schemas
# ---------------------------------------------------------------------
class Operation(BaseModel):
    """One validated transformation step applied to a single dataset column."""
    # Transformation kind understood by the simulation engine.
    op: Literal["scale", "shift", "set", "clip"]
    # Target feature/column name.
    col: str
    # Magnitude as a string, e.g. '+10%', '-5', '1.2' (omitted for 'clip').
    value: Optional[str] = None
    # Optional pandas-style row filter, e.g. 'session_count > 10'.
    where: Optional[str] = None
    # Numeric bounds, used only by 'clip' operations.
    min: Optional[float] = None
    max: Optional[float] = None
class Plan(BaseModel):
    """Top-level payload the LLM must return: an ordered list of operations."""
    plan: List[Operation]
# ---------------------------------------------------------------------
# ⚙️ Environment-Aware LLM Factory
# ---------------------------------------------------------------------
def get_llm():
    """Return an LLM instance depending on environment (local or Hugging Face).

    Detection: the SPACE_ID env var is only present inside a Hugging Face
    Space, so it selects the hosted endpoint; otherwise fall back to a
    local Ollama server.

    Model ids can be overridden via HF_REPO_ID / OLLAMA_MODEL env vars;
    the defaults preserve the original behavior.

    Returns:
        A LangChain chat model, or None when no backend is available
        (callers must handle None).
    """
    if os.getenv("SPACE_ID"):
        print("🌐 Detected Hugging Face environment — using HuggingFaceEndpoint model.")
        hf_token = os.getenv("HF_TOKEN")
        if not hf_token:
            print("⚠️ HF_TOKEN not found — please add it in Space secrets.")
            return None
        try:
            llm = ChatHuggingFace(
                llm=HuggingFaceEndpoint(
                    repo_id=os.getenv("HF_REPO_ID", "mistralai/Mistral-7B-Instruct-v0.2"),
                    huggingfacehub_api_token=hf_token,
                    temperature=0.3,
                    max_new_tokens=512,
                )
            )
            print("✅ Connected to Hugging Face Endpoint model.")
            return llm
        except Exception as e:  # boundary: report and degrade gracefully
            print(f"⚠️ Failed to connect to Hugging Face Endpoint: {e}")
            return None

    # Local Ollama fallback
    try:
        llm = ChatOllama(model=os.getenv("OLLAMA_MODEL", "mistral"))
        print("✅ Connected to local Ollama (Mistral).")
        return llm
    except Exception as e:  # boundary: report and degrade gracefully
        print(f"⚠️ Could not connect to Ollama locally: {e}")
        return None
# ---------------------------------------------------------------------
# 🧠 Shared Global Variable (used by dashboard tabs)
# ---------------------------------------------------------------------
# Holds the dataframe from the most recent simulation; written by the
# chat responder in ai_chat_factory() and read by other dashboard tabs.
# None until the first successful simulation.
latest_simulation_df = None
# ---------------------------------------------------------------------
# 💬 AI Chat Factory
# ---------------------------------------------------------------------
def ai_chat_factory(title: str = "AI Decision Assistant"):
    """
    Create a chat responder backed by either a local Ollama or a
    Hugging Face Hub model that interprets 'what-if' business questions
    and triggers validated churn simulations.

    Args:
        title: Display title for the chat tab.

    Returns:
        Tuple (respond, title) where respond(message, history) -> str
        is compatible with gr.ChatInterface.
    """
    llm = get_llm()

    # -----------------------------------------------------------------
    # 🧠 System Prompt — plain string (no format braces, so user text is
    # concatenated directly without .format() escaping concerns).
    # NOTE: the examples use double quotes so the model emits *valid*
    # JSON — single-quoted examples teach it output json.loads rejects.
    # -----------------------------------------------------------------
    system_prompt = (
        "You are an AI Decision Assistant for churn prediction and simulation.\n\n"
        "The dataset has these features:\n"
        "- session_count: number of user sessions (numeric, 0–500)\n"
        "- recency: days since last app use (numeric, 0–365)\n"
        "- avg_session_duration: average session duration in minutes (numeric, 0–180)\n\n"
        "Return ONLY a JSON object with a top-level key 'plan' (array of operations).\n"
        "Do NOT include explanations or Markdown — only raw JSON.\n\n"
        "Each operation must have:\n"
        "- op: 'scale' | 'shift' | 'set' | 'clip'\n"
        "- col: feature name\n"
        "- value: '+10%', '-5', '1.2' (omit for clip)\n"
        "- where (optional): pandas-style filter, e.g. 'session_count > 10'\n"
        "- min/max (optional): numeric bounds for clip\n\n"
        "Examples:\n"
        '{ "plan": [ {"op":"scale","col":"session_count","value":"+10%"} ] }\n'
        '{ "plan": [ {"op":"shift","col":"recency","value":"-5","where":"session_count>10"} ] }\n'
        '{ "plan": [ {"op":"clip","col":"recency","min":0,"max":90} ] }\n'
    )

    def build_prompt(user_query: str) -> str:
        """Append the user's question to the fixed system prompt."""
        return system_prompt + "\n\nUser query: " + user_query

    def _extract_json(text: str) -> str:
        """Strip Markdown code fences (``` or ```json) models often add
        despite instructions, so json.loads gets the bare payload."""
        cleaned = text.strip()
        if cleaned.startswith("```"):
            cleaned = cleaned[3:]
            if cleaned.lower().startswith("json"):
                cleaned = cleaned[4:]
            if cleaned.endswith("```"):
                cleaned = cleaned[:-3]
        return cleaned.strip()

    # -----------------------------------------------------------------
    # 💬 Chat Response Function
    # -----------------------------------------------------------------
    def respond(message, history):
        """Chat callback: prompt LLM → validate plan JSON → run simulation."""
        global latest_simulation_df

        if llm is None:
            return (
                "⚠️ LLM not connected.\n"
                "If running locally, start Ollama and pull the 'mistral' model.\n"
                "If on Hugging Face, ensure HF_TOKEN is set in Space Secrets."
            )

        # Query the LLM
        try:
            raw = llm.invoke(build_prompt(message))
            # Chat models return a message object; plain LLMs return str.
            raw_output = raw.content if hasattr(raw, "content") else str(raw)
            print("📝 LLM raw output:", raw_output)
        except Exception as e:
            return f"⚠️ LLM error: {e}"

        # Parse + validate JSON against the Plan schema
        try:
            payload = json.loads(_extract_json(raw_output))
            plan_obj = Plan(**payload)
        except json.JSONDecodeError:
            return (
                f"⚠️ Could not parse JSON.\nRaw model output:\n{raw_output}\n\n"
                "Try asking more precisely, e.g. 'Increase session_count by 10%'."
            )
        except ValidationError as e:
            return f"⚠️ Invalid plan format:\n{e}"

        # Run simulation
        try:
            # .model_dump() is Pydantic v2; fall back to v1's .dict().
            plan_dicts = [
                op.model_dump() if hasattr(op, "model_dump") else op.dict()
                for op in plan_obj.plan
            ]
            result = simulate_plan(plan=plan_dicts)
            latest_simulation_df = result.get("df")

            # Best-effort history logging — never fail the chat on it.
            try:
                log_simulation(
                    user_query=message,
                    plan=plan_dicts,
                    metrics=result.get("metrics", {}),
                )
            except Exception as e:
                print(f"⚠️ History logging failed: {e}")

            return result.get("summary", "✅ Simulation completed successfully.")
        except Exception as e:
            return f"⚠️ Simulation error: {e}"

    # -----------------------------------------------------------------
    # 🏗️ Build Chat Interface
    # -----------------------------------------------------------------
    return respond, title