# EZOFIS Vendor Onboarding AI Agent — Streamlit app.
# (Hugging Face Spaces "Sleeping" status banner removed from the captured page.)
| import streamlit as st | |
| import pandas as pd | |
| import json | |
| import os | |
| import requests | |
| import re | |
| from io import StringIO | |
| from fuzzywuzzy import fuzz | |
| from typing import Dict, Any | |
| from langchain.chat_models import ChatOpenAI | |
| from langchain.schema import HumanMessage | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import numpy as np | |
# Page chrome: wide layout plus global CSS (white cards, purple step badges and
# buttons, slider track color, oversized purple header).
st.set_page_config(page_title="EZOFIS Vendor Onboarding AI Agent", layout="wide")
st.markdown("""
<style>
.block-card { background: #fff; border-radius: 20px; box-shadow: 0 2px 16px rgba(25,39,64,0.05); padding: 32px 26px 24px 26px; margin-bottom: 24px; }
.step-num { background: #A020F0; color: #fff; border-radius: 999px; padding: 6px 13px; font-weight: 700; margin-right: 14px; font-size: 20px; display: inline-block; vertical-align: middle; }
.stButton>button { background: #A020F0 !important; color: white !important; border-radius: 12px !important; padding: 10px 32px !important; font-weight: 700; border: none !important; font-size: 18px !important; margin-top: 12px !important; }
.stSlider>div>div>div>div { background: #F3F6FB !important; border-radius: 999px; }
h1.ez-header { font-size: 2.5em; font-weight: 900; color: #A020F0; margin-bottom: 0.25em; margin-top: 8px;}
</style>
""", unsafe_allow_html=True)
# App title and tagline rendered as raw HTML so the CSS classes above apply.
st.markdown("<h1 class='ez-header'>EZOFIS Vendor Onboarding AI Agent</h1>", unsafe_allow_html=True)
st.markdown("<div style='font-size:20px; margin-bottom:18px; color:#24345C;'>AI-powered, context-aware supplier onboarding for faster compliance and better decisions.</div>", unsafe_allow_html=True)
# LLM registry: UI label -> client configuration.
# "api_env" names the environment variable holding the API key; a non-None
# "openai_api_base" routes requests to an OpenAI-compatible gateway instead of
# the default OpenAI endpoint.
MODELS = {
    "OpenAI GPT-4.1": {
        # NOTE(review): label says "GPT-4.1" but this id is the GPT-4 Turbo
        # (Nov 2023) preview — confirm which model is intended.
        "model": "gpt-4-1106-preview",
        "api_env": "OPENAI_API_KEY",
        "openai_api_base": None,
    },
    "OpenAI GPT-3.5": {
        "model": "gpt-3.5-turbo-0125",
        "api_env": "OPENAI_API_KEY",
        "openai_api_base": None,
    },
    "Mistral (OpenRouter)": {
        "model": "mistralai/ministral-8b",
        "api_env": "OPENROUTER_API_KEY",
        "openai_api_base": "https://openrouter.ai/api/v1",  # OpenRouter proxy
    }
}
def get_llm(model_choice):
    """Build a ChatOpenAI client for the selected MODELS entry.

    Halts the Streamlit script run (st.stop) with an on-page error if the
    configured API-key environment variable is unset.
    """
    config = MODELS[model_choice]
    api_key = os.getenv(config["api_env"])
    if not api_key:
        st.error(f"API key not set: {config['api_env']}")
        st.stop()  # aborts this Streamlit rerun; nothing below executes
    return ChatOpenAI(
        model=config["model"],
        openai_api_key=api_key,
        openai_api_base=config.get("openai_api_base"),  # None -> default OpenAI URL
        temperature=0.1,   # near-deterministic output so the JSON reply stays parseable
        max_tokens=2500
    )
def query_corporations_canada(company_name):
    """Look up *company_name* in the public Corporations Canada registry.

    Returns the first matching record (dict) on a successful search, or
    ``None`` when there is no match, the request fails, times out, or the
    response body is not valid JSON. Best-effort by design: callers treat
    ``None`` as "not found".
    """
    url = "https://api-prod.ised-isde.canada.ca/cc/api/v1/corporations"
    try:
        response = requests.get(url, params={"search": company_name}, timeout=8)
        if response.status_code == 200:
            # .json() itself may raise on a malformed body; caught below.
            items = response.json().get('items') or []
            return items[0] if items else None
    except (requests.RequestException, ValueError):
        # Network/timeout/decode problems all degrade to "not found".
        pass
    return None
def try_best_effort_json_parse(llm_response):
    """Parse an LLM JSON reply, tolerating common formatting noise.

    Strategy: strip markdown code fences (chat models frequently wrap JSON in
    ```json ... ```), drop trailing commas, then try strict ``json.loads``.
    If that fails, fall back to per-field regex extraction of the expected
    decision keys. Returns a dict, or ``None`` when not even a decision or
    reason could be recovered.
    """
    # Remove a leading ```/```json fence and a trailing ``` fence, if present.
    cleaned = re.sub(r'^\s*```(?:json)?\s*|\s*```\s*$', '', llm_response.strip())
    # Trailing commas before ] or } are invalid JSON but common in LLM output.
    cleaned = re.sub(r',\s*([\]}])', r'\1', cleaned)
    try:
        return json.loads(cleaned)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        result = {}
        for field in ["decision", "reason", "red_flags", "missing_documents", "next_steps"]:
            # Value is a JSON array, a quoted string, or a bare token.
            pattern = rf'"{field}"\s*:\s*(\[[^\]]*\]|"[^"]*"|[^,}}\n]+)'
            match = re.search(pattern, llm_response, re.IGNORECASE | re.DOTALL)
            if not match:
                result[field] = ""
                continue
            value = match.group(1).strip()
            if value.startswith("["):
                try:
                    result[field] = json.loads(value.replace("\n", ""))
                except ValueError:
                    # Last resort: naive split of the bracketed contents.
                    result[field] = [v.strip(' "\'') for v in value.strip('[]').split(",") if v.strip()]
            else:
                result[field] = value.strip(' "\'')
        # Accept the fallback only if it found something decision-like.
        if result.get("decision") or result.get("reason"):
            return result
        return None
def list_to_str(val):
    """Render a value as display text.

    Lists become a comma-joined string; dict items prefer their "document"
    key, then "reason", otherwise their repr. None becomes "". Anything else
    is passed through str().
    """
    if val is None:
        return ""
    if not isinstance(val, list):
        return str(val)

    def render(item):
        # Prefer the human-meaningful key when the item is a dict.
        if isinstance(item, dict):
            if "document" in item:
                return str(item["document"])
            if "reason" in item:
                return str(item["reason"])
        return str(item)

    return ", ".join(render(item) for item in val)
# Two-column layout: inputs/configuration on the left, execution/results right.
col1, col2 = st.columns([2, 4])
with col1:
    # Step 1: checklist CSV (one row per product/region/required document).
    st.markdown("<span class='step-num'>1</span> <b>Upload Supplier/Product Checklist CSV</b>", unsafe_allow_html=True)
    csv_file = st.file_uploader("CSV with product/service and required docs", type=["csv"], key="supplier_csv", label_visibility="collapsed")
    if csv_file is not None:
        df = pd.read_csv(StringIO(csv_file.getvalue().decode("utf-8")))
        st.success(f"Loaded {len(df)} checklist records.")
    else:
        df = None  # sentinel checked before the agent runs
    # Step 2: LLM backend selection (labels come from the MODELS registry).
    st.markdown("<span class='step-num'>2</span> <b>Select Model</b>", unsafe_allow_html=True)
    mdl = st.selectbox("LLM Model", list(MODELS.keys()), key="llm_model", index=0)
    # Step 3: the supplier's completed onboarding form as (possibly nested) JSON.
    st.markdown("<span class='step-num'>3</span> <b>Upload Supplier Q&A JSON</b>", unsafe_allow_html=True)
    json_file = st.file_uploader("Supplier Onboarding Q&A JSON", type=["json"], key="supplier_json", label_visibility="collapsed")
    supplier_json = None
    if json_file:
        try:
            supplier_json = json.load(json_file)
            st.success("Loaded supplier onboarding form JSON.")
        except Exception as e:
            st.error("Invalid JSON file uploaded.")
    # --- Strictness Sliders ---
    # Step 4: thresholds consumed by the matching/decision logic in col2.
    st.markdown("<span class='step-num'>4</span> <b>Agent Strictness Controls</b>", unsafe_allow_html=True)
    doc_match_threshold = st.slider("Document Match Strictness (Semantic Similarity)", 0.5, 0.95, 0.65, 0.01,
                                    help="How closely a field or file must match the checklist requirement. Higher = stricter.")
    mandatory_doc_pct = st.slider("Mandatory Document Strictness (%)", 20, 100, 100, 1,
                                  help="Minimum percentage of mandatory docs required for onboarding.")
    field_value_strict = st.slider("Field Value Strictness (0=Any, 1=Only Files/Numbers)", 0, 1, 0, 1,
                                   help="Require files/formal values (1) or allow 'Yes', 'N/A', etc. (0).")
    compliance_strict = st.slider("Compliance Registry Strictness (0=Allow, 1=Reject)", 0, 1, 1, 1,
                                  help="Reject if company not found in Canada registry (1), or allow with warning (0).")
    # Step 5: user-editable instruction text injected verbatim into the LLM prompt.
    st.markdown("<span class='step-num'>5</span> <b>Agent Instructions</b>", unsafe_allow_html=True)
    user_editable_instructions = """- For each compliance requirement, you are provided the best-matching (key, value) field pair from the supplier's data (determined by AI semantic similarity, not just spelling).
- If the value is a file (e.g., .pdf or .docx), treat this as strong evidence and DO NOT mark as missing unless there is a clear reason.
- Only mark as missing if nothing plausible is provided anywhere in the data, or if the evidence is clearly invalid."""
    agent_instruction = st.text_area("Edit agent instruction prompt:", value=user_editable_instructions, height=130, key="agent_instruction")
# --- Right column: run the onboarding agent and render the decision ---
with col2:
    st.markdown("<span class='step-num'>6</span> <b>Run Supplier Onboarding Agent</b>", unsafe_allow_html=True)
    run_agent = st.button("Run Supplier Onboarding Agent", type="primary")
    if run_agent:
        if df is None or supplier_json is None:
            st.warning("Please upload both checklist CSV and supplier JSON.")
        else:
            with st.spinner("Processing and analyzing supplier onboarding..."):
                # Sentence-embedding model used for all semantic matching below.
                # NOTE(review): re-loaded on every run; st.cache_resource on this
                # helper would avoid repeated model loads.
                def get_embedder():
                    return SentenceTransformer('all-MiniLM-L6-v2')
                embedder = get_embedder()

                def flatten_json(y: Dict[str, Any], prefix: str = '') -> Dict[str, Any]:
                    """Flatten nested dicts into dotted keys: {"a": {"b": 1}} -> {"a.b": 1}."""
                    out = {}
                    for k, v in y.items():
                        new_key = f"{prefix}.{k}" if prefix else k
                        if isinstance(v, dict):
                            out.update(flatten_json(v, new_key))
                        else:
                            out[new_key] = v
                    return out

                supplier_flat = flatten_json(supplier_json)
                # Add dash-separated aliases for uploaded files (in addition to
                # the dotted keys produced above) to aid checklist matching.
                if "Supporting Documents (File Uploads)" in supplier_json:
                    for k, v in supplier_json["Supporting Documents (File Uploads)"].items():
                        supplier_flat[f"Supporting Documents (File Uploads) - {k}"] = v

                def extract_main_info(flat_data):
                    """Heuristically pull company/country/product/region values.

                    Any key containing the keyword matches; when several keys
                    match, the last one iterated wins.
                    """
                    info = {}
                    for k, v in flat_data.items():
                        if 'company name' in k.lower():
                            info['company_name'] = v
                        if 'country' in k.lower():
                            info['country'] = v
                        if 'product' in k.lower():
                            info['product'] = v
                        if 'region' in k.lower():
                            info['region'] = v
                    return info

                main_info = extract_main_info(supplier_flat)
                # One searchable text context per checklist row.
                df['row_context'] = df.apply(lambda row: f"{row.get('Product_or_service_List','')} {row.get('Regions_Where_This_Document_Is_Required','')} {row.get('Document_Checklist','')}", axis=1)

                def get_csv_embeddings(contexts):
                    return embedder.encode(list(contexts), normalize_embeddings=True)

                csv_embeddings = get_csv_embeddings(df['row_context'])
                # Keep only the top-k checklist rows most similar to this
                # supplier's product/region/country.
                supplier_query = f"{main_info.get('product','')} {main_info.get('region','')} {main_info.get('country','')}"
                query_emb = embedder.encode([supplier_query], normalize_embeddings=True)
                sim_scores = cosine_similarity(query_emb, csv_embeddings)[0]
                top_k = min(12, len(df))
                top_idxs = sim_scores.argsort()[-top_k:][::-1]  # best-first indices
                filtered_rows = df.iloc[top_idxs].copy()
                st.markdown("#### Filtered Checklist Used")
                st.dataframe(filtered_rows, use_container_width=True)
                # "key: value" strings for every non-empty supplier field.
                supplier_pairs = [f"{k}: {v}" for k, v in supplier_flat.items() if v is not None and str(v).strip() != ""]

                def get_supplier_embeddings(supplier_pairs):
                    return embedder.encode(list(supplier_pairs), normalize_embeddings=True)

                supplier_embeddings = get_supplier_embeddings(supplier_pairs)
                # For each required document, find the single best-matching
                # supplier field and score it against the strictness controls.
                findings = []
                for idx, row in filtered_rows.iterrows():
                    doc_name = row.get('Document_Checklist', '')
                    requirement = f"{row.get('Product_or_service_List','')} {row.get('Regions_Where_This_Document_Is_Required','')} {row.get('Document_Checklist','')}"
                    requirement_emb = embedder.encode([requirement], normalize_embeddings=True)
                    sim_scores = cosine_similarity(requirement_emb, supplier_embeddings)[0]
                    best_idx = np.argmax(sim_scores)
                    best_match_text = supplier_pairs[best_idx]
                    best_score = sim_scores[best_idx]
                    is_provided = best_score >= doc_match_threshold
                    if field_value_strict == 1:
                        # Strict mode: only a file-looking name or bare number counts.
                        file_like = bool(re.search(r'\.(pdf|docx?|xls|csv|jpg|jpeg|png)$', str(best_match_text).lower()))
                        numeric_like = bool(re.match(r'^\d+(\.\d+)?$', str(best_match_text)))
                        is_provided = is_provided and (file_like or numeric_like)
                    # Split "key: value" back apart for display.
                    if ": " in best_match_text:
                        matched_key, matched_value = best_match_text.split(": ", 1)
                    else:
                        matched_key, matched_value = best_match_text, ""
                    findings.append({
                        "document": doc_name,
                        # NOTE(review): assumes the CSV cell is a string — a NaN
                        # here would raise on .strip(); confirm the CSV is clean.
                        "mandatory": row.get('Mandatory_Yes_or_No', '').strip().lower() == 'yes',
                        "matched_field": matched_key,
                        "value": matched_value,
                        "match_score": float(best_score),
                        "status": "Provided" if is_provided else "Missing"
                    })
                # Mandatory-document coverage (vacuously 100% when none are mandatory).
                n_mandatory = sum(1 for f in findings if f["mandatory"])
                n_mandatory_provided = sum(1 for f in findings if f["mandatory"] and f["status"] == "Provided")
                pct_mandatory = 100 * n_mandatory_provided / n_mandatory if n_mandatory else 100
                # External registry lookup — only attempted for Canadian suppliers.
                company_registration = None
                if main_info.get('country', '').lower() in ["canada", "ca"]:
                    company_name = main_info.get('company_name')
                    if company_name:
                        company_registration = query_corporations_canada(company_name)
                if company_registration:
                    registration_summary = {
                        "corporation_number": company_registration.get("corporationNumber"),
                        "legal_name": company_registration.get("corporationName"),
                        "status": company_registration.get("corporationStatus"),
                        "incorporation_date": company_registration.get("incorporationDate"),
                        "type": company_registration.get("corporationType"),
                    }
                else:
                    registration_summary = {"corporation_number": None, "legal_name": None, "status": None}
                # Hard rejection rules evaluated BEFORE any LLM call.
                early_reject = False
                reasons = []
                if pct_mandatory < mandatory_doc_pct:
                    early_reject = True
                    reasons.append(f"Only {pct_mandatory:.0f}% of mandatory docs provided, below {mandatory_doc_pct}% required.")
                # NOTE(review): with compliance_strict=1 this also rejects
                # non-Canadian suppliers, since no lookup is attempted for them
                # and their corporation_number stays None — confirm intended.
                if compliance_strict and (registration_summary.get("corporation_number") is None):
                    early_reject = True
                    reasons.append("Company not found in Corporations Canada registry.")
                persona_instruction = "You are a senior supplier onboarding analyst.\n\n"
                llm_return_format = """
Return JSON:
{
"decision": "ONBOARDED" | "REJECTED" | "PENDING",
"reason": "...",
"red_flags": [ ... ],
"missing_documents": [ ... ],
"next_steps": [ ... ]
}
"""
                if early_reject:
                    # Deterministic rejection — skip the LLM entirely.
                    agent_json = {
                        "decision": "REJECTED",
                        "reason": "; ".join(reasons),
                        "red_flags": reasons,
                        "missing_documents": [f["document"] for f in findings if f["mandatory"] and f["status"] == "Missing"],
                        "next_steps": ["Request all missing mandatory documents and registry proof before proceeding."]
                    }
                else:
                    # Otherwise let the LLM weigh findings, checklist, registry
                    # data, and the user-edited instructions.
                    llm_prompt = f"""
{persona_instruction}
{agent_instruction}
Supplier main info: {json.dumps(main_info, indent=2)}
Checklist document matching and findings (each with best-matching field/value from supplier data, by semantic similarity):
{json.dumps(findings, indent=2)}
Checklist (filtered for this supplier):
{json.dumps(filtered_rows.to_dict('records'), indent=2)}
External Registry (Corporations Canada) lookup:
{json.dumps(registration_summary, indent=2)}
Strictness settings:
- Document match threshold: {doc_match_threshold}
- Mandatory docs required: {mandatory_doc_pct}%
- Field value strictness: {field_value_strict}
- Compliance registry strictness: {compliance_strict}
{llm_return_format}
"""
                    try:
                        agent_llm = get_llm(mdl)
                        result = agent_llm([
                            HumanMessage(content=llm_prompt)
                        ])
                        llm_response = result.content
                        agent_json = try_best_effort_json_parse(llm_response)
                    except Exception as e:
                        # NOTE(review): swallows all LLM errors; if the call
                        # failed before assignment, llm_response is unbound and
                        # the st.code(llm_response) fallback below would raise.
                        agent_json = None
                if agent_json:
                    # Decision card: red for REJECTED, green otherwise. The six
                    # {} placeholders are filled positionally by .format below.
                    st.markdown("""
<div class='block-card'>
<table style='width:100%;font-size:1.08em;'>
<tr><th colspan='2' style='text-align:left;font-size:1.2em;'>Conclusion</th></tr>
<tr>
<td style='font-weight:700;color:#A020F0;'>Decision</td>
<td style='color:{};font-weight:700;'>{}</td>
</tr>
<tr>
<td style='font-weight:700;'>Reason</td>
<td>{}</td>
</tr>
<tr>
<td style='font-weight:700;'>Red Flags</td>
<td>{}</td>
</tr>
<tr>
<td style='font-weight:700;'>Missing Documents</td>
<td>{}</td>
</tr>
<tr>
<td style='font-weight:700;'>Next Steps</td>
<td>{}</td>
</tr>
</table>
</div>
""".format(
                        "#E53935" if agent_json.get("decision", "").upper() == "REJECTED" else "#388E3C",
                        agent_json.get("decision", "N/A"),
                        agent_json.get("reason", "—"),
                        list_to_str(agent_json.get("red_flags", [])),
                        list_to_str(agent_json.get("missing_documents", [])),
                        list_to_str(agent_json.get("next_steps", [])),
                    ), unsafe_allow_html=True)
                else:
                    st.warning("Agent did not provide a structured decision.")
                st.markdown("#### Full JSON Output")
                try:
                    st.json(agent_json if agent_json else llm_response)
                except Exception:
                    st.code(llm_response)
                st.markdown("#### Matching Details (Relevant Rows Only)")
                relevant_findings = pd.DataFrame(findings)
                st.dataframe(relevant_findings, use_container_width=True)
                if company_registration:
                    st.markdown("#### Corporations Canada Registration Lookup")
                    st.json(company_registration)