""" tabs/shared_ai.py — Version 3.0 AI Decision Assistant for Churn Simulation ------------------------------------------ • Works both locally (Ollama) and on Hugging Face (Hub Inference) • Uses validated JSON schema (Pydantic) • Adds context about available features • Calls dynamic simulation engine (simulate_plan) • Provides user-friendly error feedback """ import os import json import gradio as gr from typing import List, Literal, Optional from pydantic import BaseModel, ValidationError # LangChain components from langchain_community.chat_models import ChatOllama, ChatHuggingFace # from langchain_community.llms import HuggingFaceHub from langchain_community.chat_models import ChatOllama from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint # from langchain.prompts import PromptTemplate # Local utilities from utils.scenario_engine_ng import simulate_plan from utils.history import log_simulation # --------------------------------------------------------------------- # 🧩 Pydantic Schemas # --------------------------------------------------------------------- class Operation(BaseModel): op: Literal["scale", "shift", "set", "clip"] col: str value: Optional[str] = None where: Optional[str] = None min: Optional[float] = None max: Optional[float] = None class Plan(BaseModel): plan: List[Operation] # --------------------------------------------------------------------- # ⚙️ Environment-Aware LLM Factory # --------------------------------------------------------------------- def get_llm(): """Return an LLM instance depending on environment (local or Hugging Face).""" if os.getenv("SPACE_ID"): print("🌐 Detected Hugging Face environment — using HuggingFaceEndpoint model.") HF_TOKEN = os.getenv("HF_TOKEN") if not HF_TOKEN: print("⚠️ HF_TOKEN not found — please add it in Space secrets.") return None try: llm = ChatHuggingFace( llm=HuggingFaceEndpoint( repo_id="mistralai/Mistral-7B-Instruct-v0.2", huggingfacehub_api_token=HF_TOKEN, temperature=0.3, max_new_tokens=512 ) ) print("✅ Connected to Hugging Face Endpoint model.") return llm except Exception as e: print(f"⚠️ Failed to connect to Hugging Face Endpoint: {e}") return None else: # local Ollama fallback try: llm = ChatOllama(model="mistral") print("✅ Connected to local Ollama (Mistral).") return llm except Exception as e: print(f"⚠️ Could not connect to Ollama locally: {e}") return None # --------------------------------------------------------------------- # 🧠 Shared Global Variable (used by dashboard tabs) # --------------------------------------------------------------------- latest_simulation_df = None # --------------------------------------------------------------------- # 💬 AI Chat Factory # --------------------------------------------------------------------- def ai_chat_factory(title: str = "AI Decision Assistant"): """ Creates a Gradio ChatInterface backed by either a local Ollama or Hugging Face Hub model that interprets 'what-if' business questions and triggers validated churn simulations. """ llm = get_llm() # ----------------------------------------------------------------- # 🧠 System Prompt — plain string (no format braces) # ----------------------------------------------------------------- system_prompt = ( "You are an AI Decision Assistant for churn prediction and simulation.\n\n" "The dataset has these features:\n" "- session_count: number of user sessions (numeric, 0–500)\n" "- recency: days since last app use (numeric, 0–365)\n" "- avg_session_duration: average session duration in minutes (numeric, 0–180)\n\n" "Return ONLY a JSON object with a top-level key 'plan' (array of operations).\n" "Do NOT include explanations or Markdown — only raw JSON.\n\n" "Each operation must have:\n" "- op: 'scale' | 'shift' | 'set' | 'clip'\n" "- col: feature name\n" "- value: '+10%', '-5', '1.2' (omit for clip)\n" "- where (optional): pandas-style filter, e.g. 'session_count > 10'\n" "- min/max (optional): numeric bounds for clip\n\n" "Examples:\n" "{ 'plan': [ {'op':'scale','col':'session_count','value':'+10%'} ] }\n" "{ 'plan': [ {'op':'shift','col':'recency','value':'-5','where':'session_count>10'} ] }\n" "{ 'plan': [ {'op':'clip','col':'recency','min':0,'max':90} ] }\n" ) # Prompt builder prompt_text = system_prompt + "\n\nUser query: {query}" # prompt = PromptTemplate(input_variables=["query"], template=prompt_text) def build_prompt(user_query: str) -> str: return system_prompt + "\n\nUser query: " + user_query # ----------------------------------------------------------------- # 💬 Chat Response Function # ----------------------------------------------------------------- def respond(message, history): global latest_simulation_df if llm is None: return ( "⚠️ LLM not connected.\n" "If running locally, start Ollama and pull the 'mistral' model.\n" "If on Hugging Face, ensure HF_TOKEN is set in Space Secrets." ) prompt_text = build_prompt(message) # Query the LLM try: raw = llm.invoke(prompt_text) raw_output = raw.content if hasattr(raw, "content") else str(raw) print("🔍 LLM raw output:", raw_output) except Exception as e: return f"⚠️ LLM error: {e}" # Parse + validate JSON try: payload = json.loads(raw_output) plan_obj = Plan(**payload) except json.JSONDecodeError: return ( f"⚠️ Could not parse JSON.\nRaw model output:\n{raw_output}\n\n" "Try asking more precisely, e.g. 'Increase session_count by 10%'." ) except ValidationError as e: return f"⚠️ Invalid plan format:\n{e}" # Run simulation try: plan_dicts = [op.dict() for op in plan_obj.plan] result = simulate_plan(plan=plan_dicts) latest_simulation_df = result.get("df") try: log_simulation( user_query=message, plan=plan_dicts, metrics=result.get("metrics", {}) ) except Exception as e: print(f"⚠️ History logging failed: {e}") return result.get("summary", "✅ Simulation completed successfully.") except Exception as e: return f"⚠️ Simulation error: {e}" # ----------------------------------------------------------------- # 🎛️ Build Chat Interface # ----------------------------------------------------------------- return respond, title