Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import json | |
| from typing import List, Dict | |
| import os | |
| from dotenv import load_dotenv | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from anthropic import Anthropic | |
| import time | |
| # Import our modules | |
| from src.invoice_generator import InvoiceGenerator | |
| from src.vector_store import ContractVectorStore | |
# Load environment variables from a .env file, if one is present
# (init_llm reads ANTHROPIC_API_KEY from the environment).
load_dotenv()

# Page configuration — issued before any other Streamlit command, as
# st.set_page_config has historically required.
st.set_page_config(
    page_title="Enterprise Pricing Audit Assistant",
    # Money-bag emoji; the source held its UTF-8 bytes mis-decoded as "π°".
    page_icon="💰",
    layout="wide",
)
# Load custom CSS
def load_css(path: str = "styles.css") -> None:
    """Inject the contents of a CSS file into the page inside a <style> tag.

    Args:
        path: CSS file to read, relative to the current working directory.
            Defaults to ``"styles.css"``.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Decode explicitly as UTF-8: the platform default (e.g. cp1252 on Windows)
    # can mis-decode or reject non-ASCII characters in the stylesheet.
    with open(path, encoding="utf-8") as css_file:
        css = css_file.read()
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
# Initialize LLM client
@st.cache_resource
def init_llm():
    """Return a shared Anthropic client built from ``ANTHROPIC_API_KEY``.

    Cached with ``st.cache_resource`` so one client is reused across reruns
    and invoices. Without the cache, a fresh client (with its own HTTP client)
    would be created for every analysed invoice on every rerun.

    If the variable is unset, ``None`` is passed and the SDK's own key lookup
    and validation apply. Exceptions are not cached, so a later call retries.
    """
    return Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
# Initialize the sentence transformer model
@st.cache_resource
def load_embedding_model():
    """Return the shared ``all-MiniLM-L6-v2`` SentenceTransformer model.

    Streamlit re-executes the whole script on every widget interaction.
    Without ``st.cache_resource`` the model would be reloaded on each rerun,
    e.g. every time a different contract is selected.

    The import stays local so ``sentence_transformers`` is only imported when
    the model is first requested.
    """
    from sentence_transformers import SentenceTransformer

    return SentenceTransformer("all-MiniLM-L6-v2")
def _discrepancy_percentage(difference: float, correct_amount: float) -> float:
    """Return ``difference`` as a percentage of ``correct_amount``, rounded to 2 dp.

    A zero ``correct_amount`` would divide by zero. In that case return 0.0
    when nothing was over/under-charged, otherwise +/-inf with the sign of
    ``difference``.
    """
    if correct_amount == 0:
        if difference == 0:
            return 0.0
        return float("inf") if difference > 0 else float("-inf")
    return round(difference / correct_amount * 100, 2)


def _build_discrepancy_prompt(context: Dict, base_rate) -> str:
    """Render the LLM prompt for one invoice whose charge differs from the correct amount."""
    details = context["invoice_details"]
    terms_block = "\n".join(f"- {term}" for term in context["relevant_terms"])
    return "\n".join([
        "Analyze this invoice for pricing accuracy:",
        "Invoice Details:",
        f"- Invoice ID: {details['invoice_id']}",
        f"- Quantity: {details['quantity']}",
        f"- Charged Amount: ${details['charged_amount']:.2f}",
        f"- Correct Amount: ${details['correct_amount']:.2f}",
        f"- Date: {details['date']}",
        f"- Contract Base Rate: ${base_rate}",
        "Relevant Contract Terms:",
        terms_block or "- (no matching terms found)",
        "Discrepancy found:",
        f"- Amount Difference: ${context['discrepancy']:.2f}",
        f"- Percentage Difference: {context['discrepancy_percentage']:.2f}%",
        "Please provide a detailed explanation of:",
        "1. Why there is a pricing discrepancy",
        "2. Which contract terms were violated",
        "3. How the correct price should have been calculated",
        "Keep the explanation clear and concise, focusing on the specific "
        "pricing rules that were not properly applied.",
    ])


def _generate_explanation(prompt: str, *, model: str, max_tokens: int) -> str:
    """Ask the Anthropic model for an explanation; return a fallback message on SDK errors."""
    import anthropic  # local: only needed for the SDK's exception base class

    try:
        client = init_llm()
        response = client.messages.create(
            model=model,
            max_tokens=max_tokens,
            messages=[{"role": "user", "content": prompt}],
        )
    except anthropic.AnthropicError as err:
        # A missing key, rate limit or network failure on one invoice should
        # not propagate to main()'s top-level handler, which stops the page.
        return f"Unable to generate an explanation: {err}"
    text = "".join(getattr(block, "text", "") for block in response.content)
    return text or "The model returned no explanation."


def analyze_invoice_with_rag(
    invoice: Dict,
    contract: Dict,
    vector_store: ContractVectorStore,
    *,
    tolerance: float = 0.01,
    # NOTE(review): this model ID has been deprecated by Anthropic; confirm it
    # is still served or pass a current model.
    model: str = "claude-3-sonnet-20240229",
    max_tokens: int = 1000,
) -> Dict:
    """Compare an invoice with its correct amount and explain any discrepancy.

    Retrieves contract terms related to the invoice from ``vector_store``. When
    the charged amount differs from the correct amount by more than
    ``tolerance``, it asks the Anthropic model for an explanation.

    Args:
        invoice: Mapping with ``invoice_id``, ``quantity``, ``amount_charged``,
            ``correct_amount`` and ``date``.
        contract: Contract mapping; ``contract["terms"]["base_rate"]`` is
            included in the prompt so the model can show the calculation.
        vector_store: Object whose ``search_relevant_terms(query)`` returns
            term dicts that each carry a ``"text"`` key.
        tolerance: Absolute discrepancy (after rounding to 2 dp) at or below
            which the invoice is reported as correct. Defaults to 0.01.
        model: Anthropic model ID used for the explanation.
        max_tokens: Upper bound on the generated explanation length.

    Returns:
        A dict with these keys:

        - ``invoice_details``: the invoice fields used in the prompt.
        - ``discrepancy``: charged minus correct, rounded to 2 dp.
        - ``discrepancy_percentage``: relative to the correct amount; +/-inf
          when that amount is 0 but something was mis-charged.
        - ``explanation``: the model's explanation or a fixed message.
        - ``relevant_terms``: the full term dicts from the vector store.
    """
    base_rate = contract["terms"]["base_rate"]
    quantity = invoice["quantity"]
    charged_amount = invoice["amount_charged"]
    correct_amount = invoice["correct_amount"]

    # NOTE(review): the query does not mention the contract, and
    # initialize_data() adds every contract's terms to one store, so hits may
    # come from other contracts. Confirm whether the store can filter by contract.
    relevant_terms = vector_store.search_relevant_terms(
        f"pricing rules for quantity {quantity} and amount {charged_amount}"
    )

    difference = charged_amount - correct_amount
    context = {
        "invoice_details": {
            "invoice_id": invoice["invoice_id"],
            "quantity": quantity,
            "charged_amount": charged_amount,
            "correct_amount": correct_amount,
            "date": invoice["date"],
        },
        "relevant_terms": [term["text"] for term in relevant_terms],
        "discrepancy": round(difference, 2),
        "discrepancy_percentage": _discrepancy_percentage(difference, correct_amount),
    }

    # Only pay for an LLM call when the invoice is actually mis-priced.
    if abs(context["discrepancy"]) > tolerance:
        prompt = _build_discrepancy_prompt(context, base_rate)
        explanation = _generate_explanation(prompt, model=model, max_tokens=max_tokens)
    else:
        explanation = "Invoice pricing is correct according to contract terms."

    # The plain-text term list in `context` is replaced by the full term dicts,
    # which the UI renders via term["text"].
    return {
        **context,
        "explanation": explanation,
        "relevant_terms": relevant_terms,
    }
def display_metrics(invoices_df):
    """Render the portfolio-wide headline metrics as a single row of four.

    The four metrics, in order:

    - the invoice count;
    - how many rows are flagged ``has_error``;
    - the summed ``amount_charged``;
    - the summed ``amount_charged - correct_amount`` discrepancy.

    NOTE(review): each ``st.markdown`` call renders as its own element, so the
    opening/closing ``<div>`` pair may not actually wrap the metrics. Confirm
    the ``metrics-container`` CSS still applies.
    """
    with st.container():
        st.markdown('<div class="metrics-container">', unsafe_allow_html=True)
        metric_columns = st.columns(4)

        flagged_rows = invoices_df[invoices_df['has_error']]
        charged = invoices_df['amount_charged']
        headline = (
            ("Total Invoices", len(invoices_df)),
            ("Incorrect Invoices", len(flagged_rows)),
            ("Total Invoice Value", f"${charged.sum():,.2f}"),
            (
                "Total Pricing Discrepancy",
                f"${(charged - invoices_df['correct_amount']).sum():,.2f}",
            ),
        )

        for column, (label, value) in zip(metric_columns, headline):
            with column:
                st.metric(label, value)

        st.markdown('</div>', unsafe_allow_html=True)
def display_invoice_tables(invoices_df):
    """Show an invoice frame split into "correct" and "incorrect" tabs.

    Rows are split on the boolean ``has_error`` column. ``amount_charged`` and
    ``correct_amount`` are expected to be numeric; any other columns are shown
    under their raw names.
    """
    st.markdown('<div class="invoice-table">', unsafe_allow_html=True)

    has_error = invoices_df['has_error']
    correct_invoices = invoices_df[~has_error]
    incorrect_invoices = invoices_df[has_error]

    # Keep amounts numeric and let the grid format them. Converting them to
    # "$1,234.00" strings makes column sorting lexicographic ("$1,000.00" would
    # sort ahead of "$900.00").
    def money(label):
        return st.column_config.NumberColumn(label, format="$%.2f")

    shared_columns = {
        "invoice_id": "Invoice ID",
        "date": "Date",
        "quantity": "Quantity",
    }

    # Tab labels restored from mis-encoded emoji ("π’" / "π΄").
    tab1, tab2 = st.tabs(["🟢 Correct Invoices", "🔴 Incorrect Invoices"])
    with tab1:
        if not correct_invoices.empty:
            st.dataframe(
                correct_invoices,
                column_config={
                    **shared_columns,
                    "amount_charged": money("Amount"),
                    "correct_amount": money("Correct Amount"),
                },
                hide_index=True,
            )
        else:
            st.info("No correctly priced invoices found.")
    with tab2:
        if not incorrect_invoices.empty:
            st.dataframe(
                incorrect_invoices,
                column_config={
                    **shared_columns,
                    "amount_charged": money("Charged Amount"),
                    "correct_amount": money("Correct Amount"),
                },
                hide_index=True,
            )
        else:
            st.info("No pricing discrepancies found.")
    st.markdown('</div>', unsafe_allow_html=True)
def display_contract_details(contract):
    """Render one contract's header fields, pricing rules and special conditions.

    Expects ``contract_id``, ``client`` and a ``terms`` mapping containing
    ``base_rate``. The ``volume_discounts``, ``tiered_pricing`` and
    ``special_conditions`` entries of ``terms`` are optional and skipped when
    absent.
    """
    st.markdown('<div class="contract-details">', unsafe_allow_html=True)
    # NOTE(review): the lone "π" in this heading and in the special-conditions
    # expander looks like a mis-encoded emoji; restore the intended glyph.
    st.subheader("π Contract Details")
    terms = contract["terms"]

    col1, col2, col3 = st.columns(3)
    with col1:
        st.write("**Contract ID:**", contract['contract_id'])
    with col2:
        st.write("**Client:**", contract['client'])
    with col3:
        st.write("**Base Rate:**", f"${terms['base_rate']}")

    with st.expander("🏷️ Pricing Rules"):
        if "volume_discounts" in terms:
            st.write("**Volume Discounts:**")
            for discount in terms["volume_discounts"]:
                # ":g" drops float noise such as 0.07 * 100 == 7.000000000000001.
                st.write(
                    f"• {discount['discount'] * 100:g}% off for quantities "
                    f"≥ {discount['threshold']:,}"
                )
        if "tiered_pricing" in terms:
            st.write("**Tiered Pricing:**")
            for tier in terms["tiered_pricing"]:
                st.write(f"• {tier['tier']}: {tier['rate']}x base rate")

    with st.expander("π Special Conditions"):
        # Treated as optional, like the pricing-rule lists above. A missing key
        # would otherwise raise KeyError and stop the whole page.
        for condition in terms.get("special_conditions", []):
            st.write(f"• {condition}")
    st.markdown('</div>', unsafe_allow_html=True)
def initialize_data(data_dir: str = "data"):
    """Load contracts and invoices and index every contract's terms.

    Data is (re)generated with ``generator.generate_and_save()`` in two cases:
    when ``contracts.json`` or ``invoices.json`` is missing from ``data_dir``,
    or when either collection loads empty.

    Args:
        data_dir: Directory given to ``InvoiceGenerator`` and checked for the
            two JSON files. Defaults to ``"data"``.

    Returns:
        ``(contracts, invoices, vector_store)``, where every contract's terms
        have been added to ``vector_store``.

    Any failure is shown with ``st.error`` and the script is halted with
    ``st.stop()``.
    """
    try:
        embedding_model = load_embedding_model()
        generator = InvoiceGenerator(data_dir=data_dir)

        # Derived from data_dir so the existence check and the generator agree
        # on where the files live.
        data_files = ("contracts.json", "invoices.json")
        if not all(os.path.exists(os.path.join(data_dir, name)) for name in data_files):
            generator.generate_and_save()

        contracts = generator.load_contracts()
        invoices = generator.load_or_generate_invoices()

        if not contracts or not invoices:
            # Recoverable: the data is regenerated right here, so warn rather than error.
            st.warning("No data found. Generating new data...")
            generator.generate_and_save()
            contracts = generator.load_contracts()
            invoices = generator.load_or_generate_invoices()
            if not contracts or not invoices:
                # Fail here with a clear message. Otherwise an empty invoice
                # frame reaches display_metrics() and fails with a KeyError.
                raise ValueError(
                    "no contracts or invoices available after regenerating data"
                )

        vector_store = ContractVectorStore(embedding_model)
        for contract in contracts:
            vector_store.add_contract_terms(contract)
        return contracts, invoices, vector_store
    except Exception as e:
        st.error(f"Error initializing data: {e}")
        st.stop()
def _render_overview_tab(contract_invoices_df):
    """Show one contract's value, discrepancy and error-rate metrics plus an amount-over-time scatter."""
    charged = contract_invoices_df['amount_charged']
    total_contract_value = charged.sum()
    total_contract_discrepancy = (charged - contract_invoices_df['correct_amount']).sum()
    invoice_count = len(contract_invoices_df)
    error_count = len(contract_invoices_df[contract_invoices_df['has_error']])
    # Guard the division: a contract with no invoices would otherwise raise
    # ZeroDivisionError, which main() turns into a page-stopping error.
    error_rate = error_count / invoice_count * 100 if invoice_count else 0.0

    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Total Contract Value", f"${total_contract_value:,.2f}")
    with col2:
        st.metric("Total Discrepancy", f"${total_contract_discrepancy:,.2f}")
    with col3:
        st.metric("Error Rate", f"{error_rate:.1f}%")

    if contract_invoices_df.empty:
        st.info("No invoices found for this contract.")
        return

    has_error = contract_invoices_df['has_error']
    fig = go.Figure()
    for subset, trace_name, color in (
        (contract_invoices_df[~has_error], 'Correct Invoices', 'green'),
        (contract_invoices_df[has_error], 'Incorrect Invoices', 'red'),
    ):
        if not subset.empty:
            fig.add_trace(go.Scatter(
                x=subset['date'],
                y=subset['amount_charged'],
                mode='markers',
                name=trace_name,
                marker=dict(color=color, size=10),
            ))
    fig.update_layout(
        title='Invoice Amounts Over Time',
        xaxis_title='Date',
        yaxis_title='Amount ($)',
        hovermode='closest',
    )
    st.plotly_chart(fig, use_container_width=True)


def _render_detailed_analysis_tab(contract_invoices_df, contract, vector_store):
    """Show a retrieval + LLM discrepancy analysis for each incorrect invoice.

    Results are memoised in ``st.session_state["invoice_analyses"]``, keyed by
    contract, invoice and amounts. Streamlit reruns the script on every
    interaction and builds expander bodies even when collapsed, so without the
    cache every click would re-query the vector store and the Anthropic API
    for every incorrect invoice. A new browser session starts with an empty
    cache.
    """
    incorrect_invoices = contract_invoices_df[contract_invoices_df['has_error']]
    if incorrect_invoices.empty:
        st.info("No pricing discrepancies found for this contract.")
        return

    analyses = st.session_state.setdefault("invoice_analyses", {})
    for _, invoice in incorrect_invoices.iterrows():
        cache_key = (
            contract["contract_id"],
            invoice['invoice_id'],
            invoice['amount_charged'],
            invoice['correct_amount'],
        )
        with st.expander(f"Invoice {invoice['invoice_id']} Analysis"):
            if cache_key not in analyses:
                analyses[cache_key] = analyze_invoice_with_rag(
                    invoice.to_dict(),
                    contract,
                    vector_store,
                )
            analysis = analyses[cache_key]

            st.write("**Discrepancy Amount:**",
                     f"${analysis['discrepancy']:.2f} "
                     f"({analysis['discrepancy_percentage']}%)")
            st.write("**Relevant Contract Terms:**")
            for term in analysis['relevant_terms']:
                # Bullet restored from the mis-encoded "β’".
                st.write(f"• {term['text']}")
            st.write("**Analysis:**")
            st.write(analysis['explanation'])


def main():
    """Render the pricing-audit dashboard.

    The page contains, in order:

    - portfolio-wide metrics;
    - a contract picker and the selected contract's details;
    - three tabs: overview chart, invoice tables, per-invoice analysis.

    Unexpected errors are shown in the page and stop the script.
    """
    # A missing stylesheet is cosmetic, so only warn.
    try:
        load_css()
    except Exception as e:
        st.warning(f"Could not load custom CSS: {e}")

    # NOTE(review): the lone "π" prefixes in the title, subheader and tab
    # labels below look like mis-encoded emoji; restore the intended glyphs.
    st.title("π Enterprise Pricing Audit Assistant")

    try:
        with st.spinner('Loading data and initializing models...'):
            contracts, invoices, vector_store = initialize_data()

        invoices_df = pd.DataFrame(invoices)
        display_metrics(invoices_df)

        # Build the id -> contract lookup once, keeping the first contract per
        # id, so format_func and the selection lookup are O(1) per option.
        contracts_by_id = {}
        for contract in contracts:
            contracts_by_id.setdefault(contract["contract_id"], contract)

        selected_contract_id = st.selectbox(
            "Select Contract",
            options=list(contracts_by_id),
            format_func=lambda cid: f"{cid} - {contracts_by_id[cid]['client']}",
        )
        if selected_contract_id is None:
            st.info("No contracts available.")
            return

        selected_contract = contracts_by_id[selected_contract_id]
        display_contract_details(selected_contract)

        # .copy() so the column added below goes onto an independent frame,
        # not a slice of invoices_df (avoids SettingWithCopyWarning).
        contract_invoices_df = invoices_df[
            invoices_df['contract_id'] == selected_contract_id
        ].copy()
        # Per-invoice over/under-charge; appears as an extra column in the tables.
        contract_invoices_df['error_amount'] = (
            contract_invoices_df['amount_charged'] - contract_invoices_df['correct_amount']
        )

        st.subheader("π Invoice Analysis")
        tab1, tab2, tab3 = st.tabs(["π Overview", "π Invoice Details", "π Detailed Analysis"])
        with tab1:
            _render_overview_tab(contract_invoices_df)
        with tab2:
            display_invoice_tables(contract_invoices_df)
        with tab3:
            _render_detailed_analysis_tab(contract_invoices_df, selected_contract, vector_store)
    except Exception as e:
        st.error(f"An error occurred: {e}")
        st.stop()


if __name__ == "__main__":
    main()