# Hugging Face Hub page residue (not part of the module):
#   sakeef's picture
#   Upload dialog_manager.py with huggingface_hub
#   1cc481e verified
"""
Dialog Manager — Conversation State Tracking & Context Management
=================================================================
Manages multi-turn dialog state for Bengali public service conversations.
Responsibilities:
- Track conversation history (user + agent turns)
- Maintain slot/entity memory across turns
- Determine dialog acts (greet → inform → query → confirm → close)
- Build context windows for the response generation model
- Handle domain switching and topic transitions
Designed to work with:
- JointIntentNER model (NLU component)
- BanglaT5 response generation (NLG component)
"""
import json
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field, asdict
from enum import Enum
from datetime import datetime
# ============================================================================
# DIALOG STATES
# ============================================================================
class DialogState(Enum):
    """Coarse-grained phase a conversation is currently in."""

    # Nothing is in progress yet.
    IDLE = "idle"
    # Opening exchange of greetings.
    GREETING = "greeting"
    # Information is being given to, or gathered from, the user.
    INFORMATION = "information"
    # The user is asking a question.
    QUERY = "query"
    # Details are being confirmed.
    CONFIRMATION = "confirmation"
    # The conversation is being wrapped up.
    CLOSING = "closing"
    # A human agent is needed.
    ESCALATION = "escalation"
# Intent-to-state mapping.
# Keys are NLU intent labels; values are the DialogState the conversation
# moves to when that intent is seen. `DialogManager._transition_state` looks
# intents up here and falls back to DialogState.QUERY for labels not listed.
INTENT_STATE_MAP = {
    # Conversation framing
    "greeting": DialogState.GREETING,
    "farewell": DialogState.CLOSING,
    "thanks": DialogState.CLOSING,
    # passport_* intents
    "passport_application": DialogState.INFORMATION,
    "passport_renewal": DialogState.INFORMATION,
    "passport_status": DialogState.QUERY,
    "passport_fee": DialogState.QUERY,
    # nid_* intents
    "nid_application": DialogState.INFORMATION,
    "nid_correction": DialogState.INFORMATION,
    "nid_status": DialogState.QUERY,
    # utility_* intents
    "utility_bill_payment": DialogState.INFORMATION,
    "utility_new_connection": DialogState.INFORMATION,
    "utility_complaint": DialogState.QUERY,
    # welfare_* intents
    "welfare_application": DialogState.INFORMATION,
    "welfare_eligibility": DialogState.QUERY,
    "welfare_status": DialogState.QUERY,
    # Domain-independent intents
    "general_inquiry": DialogState.QUERY,
    "complaint": DialogState.ESCALATION,
}
# ============================================================================
# SLOT DEFINITIONS PER DOMAIN
# ============================================================================
# Slot names tracked for each service domain.
# `DialogManager` seeds a conversation's `slots` dict from the matching list
# (every value starts as None) and rebuilds it from the new domain's list when
# the inferred domain changes. Domains missing from this table get an empty
# slot dict, because lookups use `DOMAIN_SLOTS.get(domain, [])`.
# NOTE(review): `DialogManager._fill_slots` only maps entity types onto some of
# these names (e.g. nothing targets "voter_area", "bill_type" or "topic"), so
# the remaining slots are never filled by that method.
DOMAIN_SLOTS = {
    "passport": [
        "applicant_name", "nid_number", "date_of_birth",
        "passport_type", "application_type", "fee_amount",
    ],
    "nid": [
        "applicant_name", "date_of_birth", "voter_area",
        "correction_field", "nid_number",
    ],
    "utilities": [
        "account_number", "bill_type", "payment_method",
        "complaint_type", "connection_type", "area",
    ],
    "welfare": [
        "applicant_name", "age", "scheme_name",
        "eligibility_status", "application_id",
    ],
    # Default domain of a freshly started conversation.
    "general": [
        "topic", "query_type",
    ],
}
# ============================================================================
# DATA STRUCTURES
# ============================================================================
@dataclass
class Turn:
    """One utterance in a conversation, with optional NLU annotations.

    Attributes:
        role: Speaker label, ``"citizen"`` or ``"agent"``.
        text: The utterance text.
        intent: Intent label attached to the utterance, if any.
        entities: Extracted entities as ``{entity_type: value}``, if any.
        timestamp: When the turn was recorded, as a string, if known.
    """

    role: str
    text: str
    intent: Optional[str] = None
    entities: Optional[Dict[str, str]] = None
    timestamp: Optional[str] = None

    def to_dict(self) -> Dict:
        """Return the turn's fields as a plain dictionary."""
        as_mapping = asdict(self)
        return as_mapping
@dataclass
class ConversationState:
    """Everything tracked for one conversation.

    Attributes:
        dialog_id: Identifier of the conversation.
        domain: Current service domain; ``"general"`` by default.
        state: Current high-level dialog state; starts as ``IDLE``.
        turns: Recorded turns, oldest first.
        slots: Slot name → value, with ``None`` for slots not yet filled.
        turn_count: Running count of turns (starts at 0).
        confidence_scores: Recorded confidence scores, oldest first.
        created_at: ISO-format timestamp taken when the object is created.
    """

    dialog_id: str
    domain: str = "general"
    state: DialogState = DialogState.IDLE
    turns: List[Turn] = field(default_factory=list)
    slots: Dict[str, Optional[str]] = field(default_factory=dict)
    turn_count: int = 0
    confidence_scores: List[float] = field(default_factory=list)
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())

    def to_dict(self) -> Dict:
        """Return the state as a plain dict, with ``state`` as its string value."""
        # Re-assigning an existing key keeps its position, so key order
        # matches the field order produced by asdict().
        return {**asdict(self), "state": self.state.value}
# ============================================================================
# DIALOG MANAGER
# ============================================================================
class DialogManager:
    """
    Manage dialog state for multi-turn Bengali public service conversations.

    The dialog manager sits between NLU (intent + entities) and NLG (response
    generation), maintaining conversation context and determining the system's
    next action.

    Architecture:
        User Input → NLU → DialogManager.update() → context → NLG → Response

    Conversations are held in an in-memory dict keyed by ``dialog_id``.
    """

    # Intent-name prefix → domain key used in DOMAIN_SLOTS.
    _DOMAIN_PREFIXES = (
        ("passport", "passport"),
        ("nid", "nid"),
        ("utility", "utilities"),
        ("welfare", "welfare"),
    )

    # NER entity type → slot name it fills.
    _ENTITY_SLOT_MAP = {
        "PERSON": "applicant_name",
        "NID": "nid_number",
        "DATE": "date_of_birth",
        "MONEY": "fee_amount",
        "LOCATION": "area",
        "ACCOUNT": "account_number",
        "AGE": "age",
        "SCHEME": "scheme_name",
        "DOCUMENT": "passport_type",
    }

    def __init__(self, max_context_turns: int = 5, max_turns: int = 20):
        """
        Args:
            max_context_turns: Number of recent turns to include in the
                context window for response generation. Zero or a negative
                value yields an empty context.
            max_turns: Maximum turns before suggesting escalation.
        """
        self.max_context_turns = max_context_turns
        self.max_turns = max_turns
        self.conversations: Dict[str, "ConversationState"] = {}

    def start_conversation(
        self, dialog_id: str, domain: str = "general"
    ) -> "ConversationState":
        """Create and register a new conversation.

        Any conversation already stored under ``dialog_id`` is replaced.

        Args:
            dialog_id: Conversation identifier.
            domain: Initial domain; its DOMAIN_SLOTS entry (or no slots, for
                an unknown domain) seeds the slot dict with ``None`` values.

        Returns:
            The newly created conversation state.
        """
        conv = ConversationState(
            dialog_id=dialog_id,
            domain=domain,
            state=DialogState.IDLE,
            slots={slot: None for slot in DOMAIN_SLOTS.get(domain, [])},
        )
        self.conversations[dialog_id] = conv
        return conv

    def get_conversation(self, dialog_id: str) -> Optional["ConversationState"]:
        """Return the conversation stored under ``dialog_id``, or ``None``."""
        return self.conversations.get(dialog_id)

    def update(
        self,
        dialog_id: str,
        user_text: str,
        intent: str,
        entities: Dict[str, str],
        confidence: float = 1.0,
    ) -> Tuple["ConversationState", str]:
        """
        Process a user turn and update dialog state.

        A conversation is started implicitly (in the default domain) when
        ``dialog_id`` is not known yet.

        Args:
            dialog_id: Conversation identifier
            user_text: The user's utterance
            intent: Predicted intent from NLU
            entities: Extracted entities from NLU {entity_type: value}
            confidence: Intent classification confidence score

        Returns:
            (updated_state, context_for_nlg)
        """
        conv = self.conversations.get(dialog_id)
        if conv is None:
            conv = self.start_conversation(dialog_id)

        # 1. Record the user turn
        conv.turns.append(
            Turn(
                role="citizen",
                text=user_text,
                intent=intent,
                entities=entities or None,
                timestamp=datetime.now().isoformat(),
            )
        )
        conv.turn_count += 1
        conv.confidence_scores.append(confidence)

        # 2. Switch domain when the intent is domain-specific. The slot dict
        #    is rebuilt for the new domain, discarding previously filled values.
        new_domain = self._infer_domain(intent)
        if new_domain and new_domain != conv.domain:
            conv.domain = new_domain
            conv.slots = {slot: None for slot in DOMAIN_SLOTS.get(new_domain, [])}

        # 3. Update dialog state
        conv.state = self._transition_state(conv, intent, confidence)

        # 4. Fill slots from entities (after any domain switch, so values
        #    land in the new domain's slots)
        self._fill_slots(conv, entities)

        # 5. Build context for response generation
        return conv, self._build_context(conv)

    def add_agent_response(self, dialog_id: str, response_text: str):
        """Record the agent's response in conversation history.

        Does nothing when ``dialog_id`` is unknown. Agent turns do not
        increase ``turn_count``.
        """
        conv = self.conversations.get(dialog_id)
        if conv is None:
            return
        conv.turns.append(
            Turn(
                role="agent",
                text=response_text,
                timestamp=datetime.now().isoformat(),
            )
        )

    def get_filled_slots(self, dialog_id: str) -> Dict[str, str]:
        """Return the slots that have a value; empty dict for an unknown id."""
        conv = self.conversations.get(dialog_id)
        if conv is None:
            return {}
        return {k: v for k, v in conv.slots.items() if v is not None}

    def get_missing_slots(self, dialog_id: str) -> List[str]:
        """Return names of slots still unset; empty list for an unknown id."""
        conv = self.conversations.get(dialog_id)
        if conv is None:
            return []
        return [k for k, v in conv.slots.items() if v is None]

    def should_escalate(self, dialog_id: str) -> bool:
        """Check if the conversation should be escalated to a human agent.

        Escalation is suggested when the state is ESCALATION, when the user
        turn count exceeds ``max_turns``, or when the three most recent
        confidence scores are all below 0.5. Unknown ids never escalate.
        """
        conv = self.conversations.get(dialog_id)
        if conv is None:
            return False
        if conv.state == DialogState.ESCALATION:
            return True
        if conv.turn_count > self.max_turns:
            return True
        if len(conv.confidence_scores) >= 3:
            recent = conv.confidence_scores[-3:]
            if all(c < 0.5 for c in recent):
                return True
        return False

    def end_conversation(self, dialog_id: str) -> Optional[Dict]:
        """End a conversation and return its summary.

        The conversation is removed from the manager.

        Returns:
            A summary dict (id, domain, turn count, final state, filled slots,
            average confidence), or ``None`` when ``dialog_id`` is unknown.
        """
        conv = self.conversations.pop(dialog_id, None)
        if conv is None:
            return None
        scores = conv.confidence_scores
        return {
            "dialog_id": dialog_id,
            "domain": conv.domain,
            "total_turns": conv.turn_count,
            "final_state": conv.state.value,
            # Read slots from the popped state itself: the id is no longer
            # registered, so get_filled_slots(dialog_id) would return {}.
            "filled_slots": {
                k: v for k, v in conv.slots.items() if v is not None
            },
            "avg_confidence": sum(scores) / len(scores) if scores else 0,
        }

    # ------------------------------------------------------------------
    # Internal Methods
    # ------------------------------------------------------------------
    def _infer_domain(self, intent: str) -> Optional[str]:
        """Infer the domain from the intent name's prefix.

        Returns ``None`` for intents that are not domain-specific (and for an
        empty or missing intent).
        """
        if not intent:
            return None
        for prefix, domain in self._DOMAIN_PREFIXES:
            if intent.startswith(prefix):
                return domain
        return None

    def _transition_state(
        self, conv: "ConversationState", intent: str, confidence: float
    ) -> "DialogState":
        """Determine the next dialog state.

        With confidence below 0.3 the current state is kept, to avoid acting
        on an unreliable intent. Otherwise the state mapped to the intent is
        used, with QUERY for unmapped intents. Every per-state rule in the
        previous implementation returned that same mapped state, so the
        current state does not otherwise influence the result.
        """
        if confidence < 0.3:
            return conv.state
        return INTENT_STATE_MAP.get(intent, DialogState.QUERY)

    def _fill_slots(self, conv: "ConversationState", entities: Dict[str, str]):
        """Fill conversation slots from extracted entities.

        Entities whose type has no slot mapping, or whose slot is not part of
        the conversation's current slot set, are ignored. A later value for
        the same slot overwrites an earlier one.
        """
        if not entities:
            return
        for entity_type, value in entities.items():
            slot_name = self._ENTITY_SLOT_MAP.get(entity_type)
            if slot_name and slot_name in conv.slots:
                conv.slots[slot_name] = value

    def _build_context(self, conv: "ConversationState") -> str:
        """
        Build the context string for the response generation model.

        Takes the last ``max_context_turns`` turns and joins them with spaces
        as "<speaker label>: <text>" pairs.
        """
        limit = self.max_context_turns
        # Guard non-positive limits: turns[-0:] would be the whole history.
        recent_turns = conv.turns[-limit:] if limit > 0 else []
        context_parts = [
            f"নাগরিক: {turn.text}" if turn.role == "citizen"
            else f"এজেন্ট: {turn.text}"
            for turn in recent_turns
        ]
        return " ".join(context_parts)

    def get_state_summary(self, dialog_id: str) -> Dict:
        """Get a summary of current conversation state (for debugging/logging).

        Returns ``{"error": "Conversation not found"}`` for an unknown id.
        """
        conv = self.conversations.get(dialog_id)
        if conv is None:
            return {"error": "Conversation not found"}
        # Walk backwards to the latest turn carrying an intent; the very last
        # turn is often an agent response, which has none.
        last_intent = next(
            (turn.intent for turn in reversed(conv.turns) if turn.intent),
            None,
        )
        return {
            "dialog_id": dialog_id,
            "domain": conv.domain,
            "state": conv.state.value,
            "turn_count": conv.turn_count,
            "filled_slots": {k: v for k, v in conv.slots.items() if v is not None},
            "missing_slots": [k for k, v in conv.slots.items() if v is None],
            "should_escalate": self.should_escalate(dialog_id),
            "last_intent": last_intent,
        }