# user-churn / tabs/shared_ai.py
# Author: VasithaTilakumara
# Change: switch local LLM to Hugging Face Inference API (commit 8ea584a)
"""
tabs/shared_ai.py β€” Version 3.0
AI Decision Assistant for Churn Simulation
------------------------------------------
β€’ Works both locally (Ollama) and on Hugging Face (Hub Inference)
β€’ Uses validated JSON schema (Pydantic)
β€’ Adds context about available features
β€’ Calls dynamic simulation engine (simulate_plan)
β€’ Provides user-friendly error feedback
"""
# Standard library
import json
import os
from typing import List, Literal, Optional

# Third-party
import gradio as gr
from pydantic import BaseModel, ValidationError
# LangChain components — ChatHuggingFace/HuggingFaceEndpoint come from the
# dedicated langchain_huggingface package; the duplicate import of
# ChatHuggingFace from langchain_community (shadowed anyway) was dropped.
from langchain_community.chat_models import ChatOllama
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
# from langchain_community.llms import HuggingFaceHub
# from langchain.prompts import PromptTemplate

# Local utilities
from utils.scenario_engine_ng import simulate_plan
from utils.history import log_simulation
# ---------------------------------------------------------------------
# 🧩 Pydantic Schemas
# ---------------------------------------------------------------------
class Operation(BaseModel):
    """One transformation step in an LLM-produced simulation plan."""

    # Kind of transformation applied to the column.
    op: Literal["scale", "shift", "set", "clip"]
    # Target feature/column name, e.g. 'session_count'.
    col: str
    # Magnitude for scale/shift/set, e.g. '+10%', '-5', '1.2' (omitted for clip).
    value: Optional[str] = None
    # Optional pandas-style row filter, e.g. 'session_count > 10'.
    where: Optional[str] = None
    # Lower bound used by 'clip' operations.
    min: Optional[float] = None
    # Upper bound used by 'clip' operations.
    max: Optional[float] = None
class Plan(BaseModel):
    """Top-level schema the LLM must return: a list of operations under 'plan'."""

    plan: List[Operation]
# ---------------------------------------------------------------------
# βš™οΈ Environment-Aware LLM Factory
# ---------------------------------------------------------------------
def get_llm():
    """Return an LLM instance depending on environment (local or Hugging Face).

    On Hugging Face Spaces (SPACE_ID set) a hosted endpoint is used and
    HF_TOKEN must be available; otherwise a local Ollama server is tried.
    Returns None when no model can be connected.
    """
    on_spaces = bool(os.getenv("SPACE_ID"))

    if not on_spaces:
        # Local development path: Ollama-served Mistral.
        try:
            local_llm = ChatOllama(model="mistral")
            print("βœ… Connected to local Ollama (Mistral).")
            return local_llm
        except Exception as err:
            print(f"⚠️ Could not connect to Ollama locally: {err}")
            return None

    print("🌐 Detected Hugging Face environment β€” using HuggingFaceEndpoint model.")
    token = os.getenv("HF_TOKEN")
    if not token:
        # Without a token the hosted endpoint cannot be reached at all.
        print("⚠️ HF_TOKEN not found β€” please add it in Space secrets.")
        return None

    try:
        endpoint = HuggingFaceEndpoint(
            repo_id="mistralai/Mistral-7B-Instruct-v0.2",
            huggingfacehub_api_token=token,
            temperature=0.3,
            max_new_tokens=512,
        )
        remote_llm = ChatHuggingFace(llm=endpoint)
        print("βœ… Connected to Hugging Face Endpoint model.")
        return remote_llm
    except Exception as err:
        print(f"⚠️ Failed to connect to Hugging Face Endpoint: {err}")
        return None
# ---------------------------------------------------------------------
# 🧠 Shared Global Variable (used by dashboard tabs)
# ---------------------------------------------------------------------
latest_simulation_df = None
# ---------------------------------------------------------------------
# πŸ’¬ AI Chat Factory
# ---------------------------------------------------------------------
def ai_chat_factory(title: str = "AI Decision Assistant"):
    """
    Build the chat backend for the AI Decision Assistant tab.

    Returns a ``(respond, title)`` tuple. ``respond(message, history)`` is a
    Gradio-compatible chat callback backed by either a local Ollama model or
    a Hugging Face endpoint (see ``get_llm``). It interprets 'what-if'
    business questions, asks the LLM for a JSON operation plan, validates it
    against the Pydantic ``Plan`` schema, and triggers the churn simulation
    via ``simulate_plan``. All failures are returned as user-friendly
    warning strings rather than raised.
    """
    llm = get_llm()

    # -----------------------------------------------------------------
    # System prompt β€” plain string (no format braces)
    # -----------------------------------------------------------------
    system_prompt = (
        "You are an AI Decision Assistant for churn prediction and simulation.\n\n"
        "The dataset has these features:\n"
        "- session_count: number of user sessions (numeric, 0–500)\n"
        "- recency: days since last app use (numeric, 0–365)\n"
        "- avg_session_duration: average session duration in minutes (numeric, 0–180)\n\n"
        "Return ONLY a JSON object with a top-level key 'plan' (array of operations).\n"
        "Do NOT include explanations or Markdown β€” only raw JSON.\n\n"
        "Each operation must have:\n"
        "- op: 'scale' | 'shift' | 'set' | 'clip'\n"
        "- col: feature name\n"
        "- value: '+10%', '-5', '1.2' (omit for clip)\n"
        "- where (optional): pandas-style filter, e.g. 'session_count > 10'\n"
        "- min/max (optional): numeric bounds for clip\n\n"
        "Examples:\n"
        "{ 'plan': [ {'op':'scale','col':'session_count','value':'+10%'} ] }\n"
        "{ 'plan': [ {'op':'shift','col':'recency','value':'-5','where':'session_count>10'} ] }\n"
        "{ 'plan': [ {'op':'clip','col':'recency','min':0,'max':90} ] }\n"
    )

    def build_prompt(user_query: str) -> str:
        """Concatenate the fixed system prompt with the user's question."""
        return system_prompt + "\n\nUser query: " + user_query

    def _extract_json_text(text: str) -> str:
        """Best-effort isolation of a JSON object from raw LLM output.

        Models frequently wrap JSON in Markdown fences or surround it with
        prose despite instructions; slice from the first '{' to the last '}'
        so json.loads sees only the object. Returns the input unchanged when
        no brace pair is found (json.loads will then raise as before).
        """
        start = text.find("{")
        end = text.rfind("}")
        if start != -1 and end > start:
            return text[start:end + 1]
        return text

    def _op_to_dict(op):
        """Serialize an Operation across Pydantic v1 (.dict) and v2 (.model_dump)."""
        return op.model_dump() if hasattr(op, "model_dump") else op.dict()

    # -----------------------------------------------------------------
    # Chat response function
    # -----------------------------------------------------------------
    def respond(message, history):
        """Gradio chat callback: parse the question, get a plan, simulate."""
        global latest_simulation_df

        if llm is None:
            return (
                "⚠️ LLM not connected.\n"
                "If running locally, start Ollama and pull the 'mistral' model.\n"
                "If on Hugging Face, ensure HF_TOKEN is set in Space Secrets."
            )

        prompt_text = build_prompt(message)

        # Query the LLM
        try:
            raw = llm.invoke(prompt_text)
            # Chat models return a message object; plain LLMs return a string.
            raw_output = raw.content if hasattr(raw, "content") else str(raw)
            print("πŸ” LLM raw output:", raw_output)
        except Exception as e:
            return f"⚠️ LLM error: {e}"

        # Parse + validate JSON (tolerating Markdown fences / surrounding prose)
        try:
            payload = json.loads(_extract_json_text(raw_output))
            plan_obj = Plan(**payload)
        except json.JSONDecodeError:
            return (
                f"⚠️ Could not parse JSON.\nRaw model output:\n{raw_output}\n\n"
                "Try asking more precisely, e.g. 'Increase session_count by 10%'."
            )
        except ValidationError as e:
            return f"⚠️ Invalid plan format:\n{e}"

        # Run simulation
        try:
            plan_dicts = [_op_to_dict(op) for op in plan_obj.plan]
            result = simulate_plan(plan=plan_dicts)
            # Publish the simulated frame for the other dashboard tabs.
            latest_simulation_df = result.get("df")
            # Best-effort history logging β€” never let it break the chat reply.
            try:
                log_simulation(
                    user_query=message,
                    plan=plan_dicts,
                    metrics=result.get("metrics", {}),
                )
            except Exception as e:
                print(f"⚠️ History logging failed: {e}")
            return result.get("summary", "βœ… Simulation completed successfully.")
        except Exception as e:
            return f"⚠️ Simulation error: {e}"

    # -----------------------------------------------------------------
    # Hand back the callback and tab title for the caller to wire up
    # -----------------------------------------------------------------
    return respond, title