import gradio as gr
import pandas as pd
import numpy as np
from datetime import datetime
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer, util
from collections import Counter
import re

# Initialize models globally so Gradio callbacks reuse a single instance.
print("Loading models...")
try:
    # Replace with your actual fine-tuned model when uploaded to HuggingFace.
    sbert_model = SentenceTransformer("all-MiniLM-L6-v2")
    print("SBERT model loaded successfully")
except Exception as e:
    # Keep the app usable without the model; similarity features degrade to 0.
    print(f"Error loading SBERT model: {e}")
    sbert_model = None


def missing_field_score_v2(product_name, quantity, delivery_date, filename, company_name=""):
    """Score how incomplete a PO record is, normalized to [0, 1].

    Penalties (max raw total 8, divided by 8 on return):
      +2 empty product name, +1 name shorter than 3 words;
      +2 missing/zero/unparseable quantity;
      +1 missing, unparseable, or already-past delivery date;
      +0.5 each for missing filename / company name.
    """
    score = 0
    name = str(product_name).strip().lower()
    words = name.split()
    if not name:
        score += 2
    elif len(words) < 3:
        score += 1

    try:
        qty = float(quantity) if quantity else 0
        # float(quantity) may yield NaN (e.g. quantity=float('nan')), which
        # fails the <= 0 test, so check isna explicitly.
        if pd.isna(qty) or qty <= 0:
            score += 2
    except (TypeError, ValueError):
        score += 2

    if pd.isna(delivery_date) or not str(delivery_date).strip():
        score += 1
    else:
        try:
            delivery_dt = pd.to_datetime(delivery_date)
            days_to_delivery = (delivery_dt - datetime.now()).days
            if days_to_delivery <= 0:
                score += 1
        except (TypeError, ValueError):
            score += 1

    if not str(filename).strip():
        score += 0.5
    if not str(company_name).strip():
        score += 0.5

    return score / 8


def get_filename_encoding(filename):
    """Map a document filename to a heuristic risk value (0.8–3.5).

    The prefix before the first underscore (or dot, if no underscore) drives
    the lookup; unknown prefixes get a neutral 2.0 and missing names 2.5.
    """
    if pd.isna(filename) or not str(filename).strip():
        return 2.5  # Moderate for missing

    filename_str = str(filename).lower()
    # Extract filename prefix before first underscore or dot.
    if '_' in filename_str:
        prefix = filename_str.split('_')[0]
    else:
        prefix = filename_str.split('.')[0]

    # NOTE: 'manzillglobe' must be tested before the shorter 'manzill'
    # prefix below, otherwise it would match the low-risk branch.
    if prefix.startswith(('invoice', 'txn', 'mgt')):
        return 3.2  # High risk
    elif prefix.startswith(('manzillglobe', 'daljit')):
        return 3.5  # High risk
    elif prefix.startswith(('order', 'po')):
        return 0.8  # Low risk
    elif prefix.startswith(('ref', 'manzill')):
        return 1.2  # Low risk
    else:
        return 2.0  # Moderate for unknown prefixes


def delivery_lag_flag(date_str):
    """Return 1 if delivery is within 3 days (or the date is unparseable), else 0."""
    try:
        delivery_date = pd.to_datetime(date_str)
        return int((delivery_date - datetime.now()).days <= 3)
    except (TypeError, ValueError):
        # Unparseable dates are treated as urgent (conservative).
        return 1


def compute_semantic_similarity(product_name, sku_database=None):
    """Match *product_name* against a SKU catalog with SBERT embeddings.

    Returns (best_similarity, sku_code, sku_name, all_similarities).
    On any failure (no model, empty input, encode error) the tuple is
    (0.0, "", "", []) so callers can always take len() of the last element.
    """
    if not sbert_model or not product_name.strip():
        return 0.0, "", "", []

    # Default SKU database for demo.
    if not sku_database:
        sku_database = [
            {"SKU_Code": "STL001", "Product_Name": "High-quality steel bolts M8x50"},
            {"SKU_Code": "LED001", "Product_Name": "Premium LED lights 12V"},
            {"SKU_Code": "PLT001", "Product_Name": "Industrial plastic sheets"},
            {"SKU_Code": "WHE001", "Product_Name": "Heavy duty wheels 200mm"},
            {"SKU_Code": "ELE001", "Product_Name": "Electronic components kit"},
        ]

    try:
        # Encode the query and the catalog, then take cosine similarities.
        po_embedding = sbert_model.encode([product_name])
        sku_texts = [item["Product_Name"] for item in sku_database]
        sku_embeddings = sbert_model.encode(sku_texts)
        similarities = util.cos_sim(po_embedding, sku_embeddings)[0]

        best_idx = similarities.argmax().item()
        best_similarity = similarities[best_idx].item()
        matched_sku_code = sku_database[best_idx]["SKU_Code"]
        matched_sku_name = sku_database[best_idx]["Product_Name"]
        return best_similarity, matched_sku_code, matched_sku_name, similarities
    except Exception as e:
        print(f"Error in semantic similarity: {e}")
        return 0.0, "", "", []


def predict_po_risk(product_name, quantity, delivery_date, filename, company_name=""):
    """Main prediction function matching the original model logic.

    Combines missing-field, semantic-similarity, filename, delivery-urgency
    and rarity signals into a rule-based risk score (the production system
    uses a trained XGBoost model instead). Returns a dict for gr.JSON.
    """
    try:
        missing_score = missing_field_score_v2(
            product_name, quantity, delivery_date, filename, company_name
        )

        # Semantic similarity against the SKU catalog.
        cosine_similarity, matched_sku_code, matched_sku_name, similarities = (
            compute_semantic_similarity(product_name)
        )

        # Ambiguity gap: difference between the top-2 SKU matches.
        if hasattr(similarities, '__len__') and len(similarities) >= 2:
            sorted_sims = sorted(similarities, reverse=True)
            ambiguity_gap = float(sorted_sims[0] - sorted_sims[1])
        else:
            ambiguity_gap = 0.0

        filename_encoding = get_filename_encoding(filename)
        delivery_lag = delivery_lag_flag(delivery_date)

        # Simple semantic signal (PCA would normally be applied here).
        semantic_signal = cosine_similarity - 0.5  # Normalized around 0

        # Token rarity (simplified — the real model uses corpus statistics).
        words = str(product_name).lower().split()
        description_rarity = 1.0 / (len(words) + 1) if words else 1.0

        # Combine features for risk prediction (simplified rule-based).
        # The actual model would use the trained XGBoost classifier here.
        risk_factors = [
            missing_score * 3.0,              # Weight missing fields heavily
            (1.0 - cosine_similarity) * 2.0,  # Low similarity = higher risk
            filename_encoding / 4.0,          # Normalize filename score
            delivery_lag * 1.5,               # Urgent delivery increases risk
            description_rarity * 1.0,         # Rare descriptions are riskier
        ]
        risk_score = np.mean(risk_factors)

        # Bucket the score into a risk level with a heuristic confidence.
        if risk_score > 0.7:
            predicted_risk = "High"
            confidence = min(0.95, 0.6 + risk_score * 0.35)
        elif risk_score > 0.4:
            predicted_risk = "Medium"
            confidence = 0.75
        else:
            predicted_risk = "Low"
            confidence = min(0.95, 0.85 - risk_score * 0.3)

        return {
            "đŸŽ¯ Risk Level": predicted_risk,
            "📊 Risk Score": f"{risk_score:.3f}",
            "🎲 Confidence": f"{confidence:.3f}",
            "❌ Missing Field Score": f"{missing_score:.3f}",
            "🔍 Cosine Similarity": f"{cosine_similarity:.3f}",
            "📂 Filename Risk Score": f"{filename_encoding:.1f}",
            "⚡ Delivery Urgency": "Yes" if delivery_lag else "No",
            "đŸˇī¸ Matched SKU Code": matched_sku_code or "No match",
            "📝 Matched SKU Name": matched_sku_name or "No match",
            "🔄 Semantic Signal": f"{semantic_signal:.3f}",
            "🔤 Description Rarity": f"{description_rarity:.3f}",
        }
    except Exception as e:
        # Surface the failure in the UI instead of crashing the callback.
        return {"❌ Error": f"Prediction failed: {str(e)}"}


# Create Gradio interface
with gr.Blocks(title="PO Risk Validator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📋 Purchase Order Risk Validator")
    gr.Markdown("## AI-powered analysis to assess PO risk using semantic matching and XGBoost prediction")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📝 Enter PO Details")
            product_name = gr.Textbox(
                label="Product Name",
                placeholder="e.g., High-quality steel bolts M8x50",
                info="Detailed product description helps improve accuracy",
                lines=2,
            )
            with gr.Row():
                quantity = gr.Number(
                    label="Quantity",
                    value=1,
                    minimum=0,
                    info="Order quantity",
                )
                delivery_date = gr.Textbox(
                    label="Delivery Date",
                    placeholder="2025-08-15",
                    info="Expected delivery date (YYYY-MM-DD)",
                )
            filename = gr.Textbox(
                label="Document Filename",
                placeholder="invoice_001.pdf",
                info="Original document filename",
            )
            company_name = gr.Textbox(
                label="Company Name (Optional)",
                placeholder="SteelCorp Ltd.",
                info="Supplier company name",
            )
            predict_btn = gr.Button("🔍 Analyze PO Risk", variant="primary", size="lg")

        with gr.Column(scale=1):
            gr.Markdown("### 📊 Risk Assessment Results")
            output = gr.JSON(label="Analysis Results", show_label=False)
            gr.Markdown("### â„šī¸ Understanding the Results")
            gr.Markdown("""
            - **Risk Level**: Overall assessment (Low/Medium/High)
            - **Risk Score**: Numerical risk value (0-1, higher = riskier)
            - **Confidence**: Model confidence in prediction
            - **Missing Field Score**: Penalty for incomplete data
            - **Cosine Similarity**: Semantic match with SKU database
            - **Filename Risk Score**: Risk based on document type
            - **Delivery Urgency**: Whether delivery is within 3 days
            """)

    # Examples section
    gr.Markdown("### đŸŽ¯ Try These Examples")
    examples = [
        ["High-quality steel bolts M8x50", 100, "2025-08-15", "order_ref_001.pdf", "SteelCorp Ltd"],
        ["", 0, "", "invoice_urgent.pdf", ""],  # High risk example
        ["Premium LED lights 12V", 50, "2025-09-01", "po_standard_123.pdf", "LightTech Inc"],
        ["Industrial grade components", 25, "2025-07-30", "txn_immediate.pdf", "QuickSupply Co"],
    ]
    gr.Examples(
        examples=examples,
        inputs=[product_name, quantity, delivery_date, filename, company_name],
        outputs=output,
        fn=predict_po_risk,
        cache_examples=True,
        label="Sample PO Data",
    )

    # Connect the button
    predict_btn.click(
        fn=predict_po_risk,
        inputs=[product_name, quantity, delivery_date, filename, company_name],
        outputs=output,
    )

    gr.Markdown("---")
    gr.Markdown("### 🚀 About This Model")
    gr.Markdown("""
    This demo showcases a simplified version of the PO Risk Validator. The full production model includes:
    - Fine-tuned Sentence-BERT for semantic product matching
    - XGBoost classifier trained on historical PO data
    - Advanced feature engineering and PCA dimensionality reduction
    - Real-time SKU database integration
    """)

# Launch the app
if __name__ == "__main__":
    demo.launch(share=True)