# Source: lfm2-transaction-encoder/encoder/src/demo/copilot_app_fraud_pattern.py (5.89 kB)
# Last commit b3112c7 by cdotsanghvi: "initial transaction co-pilot deployment"
"""Gradio app for the Fraud Pattern Co-Pilot demo.

Mirrors the Dispute/Collections apps. Flow: pick a member from the cast
strip → review their context → press Analyze → see the timeline, the
two-distribution scoreboard, and the streaming reasoning.

CLI:
python -m encoder.src.demo.copilot_app_fraud_pattern \\
--checkpoint encoder/experiments/fraud_pattern_v1/demo_checkpoint.pt \\
--port 7864
"""
from __future__ import annotations
import argparse
from pathlib import Path
import gradio as gr
import torch
from encoder.src.demo.copilot_inference_fraud_pattern import (
FraudPatternCastMember,
FraudPatternCopilotModel,
)
from encoder.src.demo.copilot_render_fraud_pattern import (
render_cast_strip,
render_context,
render_header,
render_reasoning,
render_timeline,
render_two_dist_scoreboard,
)
_CONTAINER_WIDTH_PX = 1200
def _build_tab(
    model: FraudPatternCopilotModel,
    *,
    top_k: int = 5,
    chunk_chars: int = 6,
) -> None:
    """Build the Fraud surface into the current Gradio context.

    Must be called inside an open ``gr.Blocks`` (or nested layout) context;
    every component created here attaches to whatever layout is active.

    Layout, top to bottom:
      * cast strip, plus one selector button per cast member;
      * the selected member's context card;
      * a centred Analyze button;
      * the transaction timeline;
      * a row with the two-distribution scoreboard and the reasoning panel.

    Args:
        model: Loaded co-pilot model. ``model.cast`` supplies the selectable
            examples. ``model.predict`` and ``model.stream_reasoning`` drive
            the Analyze action.
        top_k: Value passed as ``top_k`` to ``model.predict``.
        chunk_chars: Value passed as ``chunk_chars`` to
            ``model.stream_reasoning``.

    Raises:
        ValueError: If ``model.cast`` is empty. The initial render needs at
            least one member.
    """
    cast = model.cast
    if not cast:
        # Every initial render below reads cast[0]; fail with a clear message
        # instead of an IndexError from deep inside UI construction.
        raise ValueError(
            "model.cast is empty; the Fraud tab needs at least one cast member"
        )

    # Index of the currently selected cast member; read by Analyze.
    selected_idx = gr.State(value=0)

    cast_html = gr.HTML(render_cast_strip(cast, 0))
    with gr.Row():
        cast_buttons: list[gr.Button] = []
        for i, m in enumerate(cast):
            # Button label: first two space-separated words of the display name.
            short = " ".join(m.display_name.split(" ")[:2])
            btn = gr.Button(
                value=short,
                variant="primary" if i == 0 else "secondary",
                scale=1,
            )
            cast_buttons.append(btn)

    context_html = gr.HTML(render_context(cast[0]))

    # Empty flex spacers on both sides centre the fixed-width Analyze button.
    with gr.Row():
        gr.HTML("<div style='flex: 1'></div>")
        analyze_btn = gr.Button(
            value="Analyze",
            variant="primary",
            size="lg",
            scale=0,
            min_width=180,
        )
        gr.HTML("<div style='flex: 1'></div>")

    timeline_html = gr.HTML(render_timeline(flagged_idx=cast[0].flagged_idx))
    with gr.Row(equal_height=True):
        with gr.Column(scale=3):
            scoreboard_html = gr.HTML(render_two_dist_scoreboard(None))
        with gr.Column(scale=2):
            reasoning_html = gr.HTML(render_reasoning(None))

    def _select(idx: int) -> tuple:
        """Return output updates that switch the UI to cast member ``idx``.

        Resets the scoreboard and reasoning panels to their initial (``None``)
        render, the same as on page load. Highlights the clicked button as
        "primary" and all others as "secondary".
        """
        member = cast[idx]
        button_updates = tuple(
            gr.update(variant="primary" if i == idx else "secondary")
            for i in range(len(cast))
        )
        return (
            idx,
            render_cast_strip(cast, idx),
            render_context(member),
            render_timeline(flagged_idx=member.flagged_idx),
            render_two_dist_scoreboard(None),
            render_reasoning(None),
        ) + button_updates

    def _analyze(idx: int):
        """Run the model on cast member ``idx`` and stream results to the UI.

        Yields ``(timeline, scoreboard, reasoning)`` HTML triples. The first
        triple carries empty reasoning. Each later one carries the next
        partial text from ``model.stream_reasoning``.
        """
        member: FraudPatternCastMember = cast[idx]
        result = model.predict(member, top_k=top_k)
        timeline = render_timeline(
            flagged_idx=member.flagged_idx,
            top_k_positions=result.top_k_positions,
            attribution_probs=result.attribution_probs,
        )
        scoreboard = render_two_dist_scoreboard(result)
        yield timeline, scoreboard, render_reasoning("")
        for partial in model.stream_reasoning(
            member, result, chunk_chars=chunk_chars
        ):
            yield timeline, scoreboard, render_reasoning(partial)

    # Register Analyze first so the selector buttons below can cancel it.
    analyze_event = analyze_btn.click(
        fn=_analyze,
        inputs=[selected_idx],
        outputs=[timeline_html, scoreboard_html, reasoning_html],
    )

    select_outputs = [
        selected_idx, cast_html, context_html,
        timeline_html, scoreboard_html, reasoning_html,
        *cast_buttons,
    ]
    for i, btn in enumerate(cast_buttons):
        btn.click(
            # Bind i as a default argument to avoid late-binding closures.
            fn=lambda i=i: _select(i),
            inputs=None,
            outputs=select_outputs,
            # Without this, an in-flight Analyze stream keeps writing the
            # previous member's results over the panels _select just reset.
            cancels=[analyze_event],
        )
def _build_ui(model: FraudPatternCopilotModel) -> gr.Blocks:
    """Assemble the standalone Fraud Pattern Co-Pilot app (not launched).

    Creates a titled ``gr.Blocks`` and renders the Fraud-specific header
    into it. The interactive surface is delegated to ``_build_tab``.
    """
    app = gr.Blocks(title="Fraud Pattern Co-Pilot — Liquid AI")
    with app:
        header_markup = render_header()
        gr.HTML(header_markup)
        _build_tab(model)
    return app
def main() -> None:
    """Parse CLI arguments, load the model, and serve the Gradio demo.

    Validates the requested torch device before loading anything. Loads a
    ``FraudPatternCopilotModel`` from the given paths, builds the UI, and
    launches it with queuing enabled, which the streaming Analyze handler
    relies on.

    Raises:
        SystemExit: On invalid CLI arguments, an unavailable device, or an
            empty cast file.
    """
    parser = argparse.ArgumentParser(
        description="Fraud Pattern Co-Pilot Gradio demo",
    )
    parser.add_argument(
        "--checkpoint",
        type=Path,
        default=Path("encoder/experiments/fraud_pattern_v1/demo_checkpoint.pt"),
    )
    parser.add_argument(
        "--model-config",
        type=Path,
        default=Path("encoder/configs/model_fraud_pattern.yaml"),
    )
    parser.add_argument("--schema", type=Path, default=Path("data/schema.yaml"))
    parser.add_argument(
        "--histories",
        type=Path,
        default=Path("data/synthetic/token_ids.npy"),
    )
    parser.add_argument(
        "--cast",
        type=Path,
        default=Path("encoder/data/fraud_pattern_cast.json"),
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda", "mps"],
    )
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Interface to bind the server to (default: 0.0.0.0, all interfaces).",
    )
    parser.add_argument("--port", type=int, default=7864)
    parser.add_argument("--share", action="store_true")
    args = parser.parse_args()

    # torch.device() happily constructs "cuda"/"mps" even when the backend is
    # missing; fail here with a clear CLI error instead of mid-checkpoint-load.
    if args.device == "cuda" and not torch.cuda.is_available():
        parser.error("--device cuda requested but CUDA is not available")
    if args.device == "mps" and not torch.backends.mps.is_available():
        parser.error("--device mps requested but MPS is not available")

    device = torch.device(args.device)
    print(f"Loading FraudPatternCopilotModel on {device} ...")
    model = FraudPatternCopilotModel.from_paths(
        checkpoint_path=args.checkpoint,
        model_config_path=args.model_config,
        schema_path=args.schema,
        histories_path=args.histories,
        cast_path=args.cast,
        device=device,
    )
    print(f"  cast size: {len(model.cast)}")
    if not model.cast:
        # The UI renders cast[0] on load; an empty cast cannot be displayed.
        raise SystemExit(f"No cast members loaded from {args.cast}; nothing to demo.")

    demo = _build_ui(model)
    # queue() is needed for the generator-based streaming Analyze handler.
    demo.queue().launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        theme=gr.themes.Default(
            font=["Inter", "system-ui", "sans-serif"],
            font_mono=["JetBrains Mono", "ui-monospace", "monospace"],
        ),
        css=f"""
.gradio-container {{
    max-width: {_CONTAINER_WIDTH_PX}px !important;
    background: #fafafa !important;
}}
""",
    )


if __name__ == "__main__":
    main()