# Listing metadata: file size 7,772 bytes, revision e41598d.
import streamlit as st
from faiss import IndexFlatL2
import numpy as np
from transformers import T5Tokenizer, T5ForConditionalGeneration
from graphviz import Digraph
import torch
# Initialize T5 model for both summarization and feature extraction.
# Streamlit versions without st.cache_resource fall back to the uncached load.
_cache_resource = getattr(st, "cache_resource", lambda fn: fn)


@_cache_resource
def _load_t5(model_name="t5-large"):
    """Load the T5 tokenizer and model for ``model_name``.

    Cached (when Streamlit supports it) so the weights are loaded once per
    server process rather than each time the script is re-executed.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    return (
        T5Tokenizer.from_pretrained(model_name),
        T5ForConditionalGeneration.from_pretrained(model_name),
    )


tokenizer, model = _load_t5()

# Initialize FAISS index. The dimension is read from the model config so it
# stays in step with the checkpoint (1024 for t5-large).
index = IndexFlatL2(model.config.d_model)
responses = []  # Store responses for FAISS
# Streamlit page setup: title plus a reference table of the reasoning methods.
st.title("Banking AI Decision-Making Assistant 🤖")
# Blank lines keep the heading, the intro sentence and the table as separate
# Markdown blocks.
st.markdown("""
### **Understanding AI Reasoning Methods**

Before answering, here’s how AI chooses the best reasoning method for your banking use case:

| Method | How It Works | Best For | Example Use Cases |
|--------|-------------|----------|------------------|
| **Chain-of-Thought (CoT)** | Step-by-step reasoning | Math, logic, coding | Solving word problems, code generation |
| **Tree-of-Thoughts (ToT)** | Explores multiple solution paths | Games, decision-making | Chess, strategic planning |
| **Self-Consistency (SC)** | Selects the most frequent correct answer | Fact-checking, accuracy | Medical diagnosis, legal cases |
| **PAL (Program-Aided LMs)** | Uses external tools for precise answers | Math, finance, databases | Financial projections, data queries |
| **ReAct (Reasoning + Acting)** | AI interacts with tools & takes actions | AI Agents, automation | AI assistants, automated workflows |
| **Graph-of-Thoughts (GoT)** | Thoughts form a flexible network | Research, innovation | Scientific discovery, brainstorming |
""")
# Collect use case from the user
st.markdown("### **Describe Your Banking Use Case:**")
use_case = st.text_area("Enter your banking use case:")

# Initialize session state for responses
if 'responses' not in st.session_state:
    st.session_state.responses = {}

# One (state key, checkbox label) pair per yes/no question, in display order.
# The key doubles as the widget key and as the entry in
# st.session_state.responses.
# NOTE(review): the leading emoji were mis-encoded in the source. The glyphs
# for 'multiple_factors' and 'security_concern' could not be recovered and are
# best-guess replacements - confirm.
_QUESTIONS = (
    (
        "multiple_factors",
        "📊 Does this involve multiple decision factors? (e.g., risk, compliance, fraud)",
    ),
    (
        "real_time_validation",
        "⏳ Does this require real-time validation? (e.g., fraud detection, transaction monitoring)",
    ),
    (
        "user_feedback",
        "👥 Does this need user feedback handling? (e.g., customer disputes, support tickets)",
    ),
    (
        "complexity",
        "🧩 Is the decision-making process complex? (e.g., multi-step approvals, AI model predictions)",
    ),
    (
        "security_concern",
        "🔒 Are there security concerns? (e.g., sensitive data, encryption, compliance)",
    ),
    (
        "automation_level",
        "🤖 Is this process fully automated? (e.g., auto-loan approvals, AI-driven compliance checks)",
    ),
)

# Collect responses from user - checkboxes for independent selection
for _key, _label in _QUESTIONS:
    st.session_state.responses[_key] = st.checkbox(_label, value=False, key=_key)
# Function to determine the best AI method based on responses
def determine_method(responses):
    """Pick the reasoning method that matches the questionnaire answers.

    Rules are checked in priority order and the first match wins:

    1. multiple factors AND complexity -> Tree-of-Thoughts
    2. real-time validation            -> PAL
    3. user feedback                   -> ReAct
    4. security concern                -> Self-Consistency
    5. automation level                -> PAL
    6. otherwise                       -> Chain-of-Thought

    Args:
        responses: Mapping of question key to a truthy/falsy answer. Missing
            keys (or ``None`` for the whole mapping) count as "No".

    Returns:
        A ``(method_name, rationale)`` tuple of strings.
    """
    answers = responses or {}

    def said_yes(key):
        # An unanswered question is treated the same as an unchecked box.
        return bool(answers.get(key, False))

    if said_yes('multiple_factors') and said_yes('complexity'):
        return (
            "Tree-of-Thoughts (ToT)",
            "Yes, this requires multi-step decision-making, strategic planning.",
        )
    if said_yes('real_time_validation'):
        return (
            "PAL (Program-Aided LMs)",
            "Yes, this requires real-time data validation for fraud detection or monitoring.",
        )
    if said_yes('user_feedback'):
        return (
            "ReAct (Reasoning + Acting)",
            "Yes, this involves dynamic user feedback handling.",
        )
    if said_yes('security_concern'):
        return (
            "Self-Consistency (SC)",
            "Yes, there are concerns regarding data security and accuracy.",
        )
    if said_yes('automation_level'):
        return (
            "PAL (Program-Aided LMs)",
            "Yes, this process requires a fully automated system with external tools.",
        )
    return (
        "Chain-of-Thought (CoT)",
        "No, this decision-making is more straightforward and does not involve complex factors.",
    )
# Function to store responses in FAISS index
def store_response(responses):
    """Embed the questionnaire answers with the T5 encoder and add them to FAISS.

    The answers are flattened into a single ``"key: value ..."`` string, run
    through ``model.encoder``, averaged over ``dim=1`` of the last hidden
    state, and the resulting vector is appended to the module-level ``index``.

    Args:
        responses: Mapping of question key to the user's answer.
    """
    response_str = " ".join(f"{key}: {value}" for key, value in responses.items())
    inputs = tokenizer(response_str, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():  # Inference only; no gradients needed
        outputs = model.encoder(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
    # Use mean of hidden states as embedding
    pooled = outputs.last_hidden_state.mean(dim=1)
    # FAISS expects a C-contiguous float32 matrix; make that explicit instead
    # of relying on the model's default dtype.
    embeddings = np.ascontiguousarray(pooled.cpu().numpy(), dtype=np.float32)
    # Add the embeddings to the FAISS index
    index.add(embeddings)
# Function to generate a decision tree visualization
def visualize_decision_tree(responses, selected_method, rationale):
    """Render the questionnaire answers and the chosen method as a Graphviz chain.

    The chart is a single path: use-case input, one node per question (in
    questionnaire order), then the recommended method with its rationale.
    Each question node is labelled ``"<factor>: <answer>"`` so the chart can be
    read without knowing the node order.

    Args:
        responses: Mapping of question key to a truthy/falsy answer; missing
            keys are shown as the "no" answer.
        selected_method: Name of the recommended method.
        rationale: Explanation shown under the method name.
    """
    # (node id / factor name, responses key, label if truthy, label if falsy)
    factors = (
        ("Multiple Decision Factors", 'multiple_factors', "Yes", "No"),
        ("Real-Time Validation", 'real_time_validation', "Yes", "No"),
        ("User Feedback Handling", 'user_feedback', "Yes", "No"),
        ("Complexity", 'complexity', "High", "Low"),
        ("Security Concern", 'security_concern', "Yes", "No"),
        ("Automation Level", 'automation_level', "Automated", "Human Oversight"),
    )

    dot = Digraph()
    dot.node("Use Case Input", "🏦 Banking Use Case")

    # Add each factor node and link it to the node before it.
    previous = "Use Case Input"
    for title, key, if_yes, if_no in factors:
        answer = if_yes if responses.get(key) else if_no
        dot.node(title, f"{title}: {answer}")
        dot.edge(previous, title)
        previous = title

    dot.node("Final Method", f"🎯 {selected_method}\nRationale: {rationale}")
    dot.edge(previous, "Final Method")
    st.graphviz_chart(dot)
# Summarization using T5
def get_summary(use_case):
    """Summarize ``use_case`` with T5.

    Args:
        use_case: Free-text description entered by the user.

    Returns:
        The generated summary string, or ``None`` when ``use_case`` is empty
        or whitespace-only, or when generation fails (the error is shown in
        the Streamlit page).
    """
    # Nothing to summarize: skip the model call rather than summarizing the
    # bare "summarize: " prefix.
    if not use_case or not use_case.strip():
        return None
    try:
        inputs = tokenizer(
            f"summarize: {use_case}",
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )
        summary_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=200,
            num_beams=4,
            early_stopping=True,
        )
        return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    except Exception as e:  # UI boundary: report the failure instead of crashing the page
        # NOTE(review): the leading glyph was mis-encoded in the source; ❌ is
        # a best-guess replacement - confirm.
        st.error(f"❌ Error generating summary: {e}")
        return None
# Analyze button
# NOTE(review): the emoji in this handler were mis-encoded in the source and
# could not be recovered; the glyphs below are best-guess replacements - confirm.
if st.button("🔍 Analyze Use Case"):
    # Get summary of the use case; a blank description has nothing to summarize.
    if use_case.strip():
        summary = get_summary(use_case)
    else:
        st.warning("Please describe your banking use case to get a summary.")
        summary = None
    if summary:
        st.write("### **📝 Summary of Your Use Case:**")
        st.write(summary)

    # Determine the best AI method (depends only on the checkbox answers)
    method, rationale = determine_method(st.session_state.responses)
    st.write(f"## 📌 Recommended AI Method: {method}")
    st.write(f"### Reasoning: {rationale}")

    # Store response in FAISS index
    store_response(st.session_state.responses)

    # Visualize the decision tree
    visualize_decision_tree(st.session_state.responses, method, rationale)