# NOTE(review): the lines "Spaces: / Sleeping / Sleeping" were HuggingFace
# web-page status residue from extraction, not source code; removed so the
# file parses.
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| from datetime import datetime | |
| import torch | |
| import torch.nn.functional as F | |
| from sentence_transformers import SentenceTransformer, util | |
| from collections import Counter | |
| import re | |
# Load the sentence-embedding model once at import time so every request
# reuses the same instance; on failure the app degrades gracefully
# (compute_semantic_similarity checks for None).
print("Loading models...")


def _load_sbert():
    """Return the shared SBERT encoder, or None when loading fails."""
    try:
        # Replace with your actual model when uploaded to HuggingFace
        model = SentenceTransformer("all-MiniLM-L6-v2")
        print("SBERT model loaded successfully")
        return model
    except Exception as e:
        print(f"Error loading SBERT model: {e}")
        return None


sbert_model = _load_sbert()
def missing_field_score_v2(product_name, quantity, delivery_date, filename, company_name=""):
    """Score how incomplete or suspicious a PO's fields are.

    Each missing/invalid field adds a fixed penalty, and the total is
    divided by 8 for parity with the original training-time feature.

    Args:
        product_name: Free-text product description.
        quantity: Order quantity (numeric or numeric string).
        delivery_date: Expected delivery date (anything pd.to_datetime accepts).
        filename: Source document filename.
        company_name: Optional supplier name.

    Returns:
        float: normalized penalty. Max raw score is 6 (2+2+1+0.5+0.5), so the
        result lies in [0, 0.75]. NOTE(review): the /8 divisor therefore never
        reaches 1.0 -- kept for model parity, confirm it is intentional.
    """
    score = 0.0

    # Product name: empty is worst; very short descriptions get a mild penalty.
    name = str(product_name).strip().lower()
    words = name.split()
    if not name:
        score += 2
    elif len(words) < 3:
        score += 1

    # Quantity: non-numeric, NaN, or non-positive values are all penalized.
    try:
        qty = float(quantity) if quantity else 0
        if pd.isna(qty) or qty <= 0:
            score += 2
    except (TypeError, ValueError):  # was a bare `except:` -- narrowed
        score += 2

    # Delivery date: missing, unparseable, or already in the past.
    if pd.isna(delivery_date) or not str(delivery_date).strip():
        score += 1
    else:
        try:
            delivery_dt = pd.to_datetime(delivery_date)
            days_to_delivery = (delivery_dt - datetime.now()).days
            if days_to_delivery <= 0:
                score += 1
        except (TypeError, ValueError):  # parse failure or tz-aware mismatch
            score += 1

    # Light penalties for missing metadata.
    if not str(filename).strip():
        score += 0.5
    if not str(company_name).strip():
        score += 0.5

    return score / 8
def get_filename_encoding(filename):
    """Map a document filename to a risk score based on its name prefix.

    Missing filenames and unknown prefixes get moderate scores; known
    prefixes are looked up in an ordered risk table.
    """
    if pd.isna(filename) or not str(filename).strip():
        return 2.5  # Moderate for missing

    lowered = str(filename).lower()
    # Prefix is the part before the first underscore, or before the first
    # dot when no underscore is present.
    separator = '_' if '_' in lowered else '.'
    prefix = lowered.split(separator)[0]

    # Ordered prefix -> risk lookup; first match wins, so the more specific
    # 'manzillglobe' entry is consulted before the generic 'manzill' one.
    risk_table = (
        (('invoice', 'txn', 'mgt'), 3.2),       # High risk
        (('manzillglobe', 'daljit'), 3.5),      # High risk
        (('order', 'po'), 0.8),                 # Low risk
        (('ref', 'manzill'), 1.2),              # Low risk
    )
    for known_prefixes, risk in risk_table:
        if prefix.startswith(known_prefixes):
            return risk
    return 2.0  # Moderate for unknown prefixes
def delivery_lag_flag(date_str):
    """Return 1 when delivery is urgent (within 3 days), else 0.

    Unparseable dates are treated as urgent (risk-conservative default).
    """
    try:
        delivery_date = pd.to_datetime(date_str)
        return int((delivery_date - datetime.now()).days <= 3)
    except Exception:
        # Was a bare `except:` -- narrowed so KeyboardInterrupt/SystemExit
        # are no longer swallowed; any parse/arithmetic failure flags urgent.
        return 1
def compute_semantic_similarity(product_name, sku_database=None):
    """Find the closest SKU for a product name via SBERT cosine similarity.

    Returns a 4-tuple (best_similarity, sku_code, sku_name, similarities).
    NOTE: on success the 4th element is the full similarity tensor, while on
    failure it is the float 0.0 -- callers distinguish via hasattr '__len__'.
    """
    no_match = (0.0, "", "", 0.0)
    if not sbert_model or not product_name.strip():
        return no_match

    # Default SKU database for demo
    if not sku_database:
        sku_database = [
            {"SKU_Code": "STL001", "Product_Name": "High-quality steel bolts M8x50"},
            {"SKU_Code": "LED001", "Product_Name": "Premium LED lights 12V"},
            {"SKU_Code": "PLT001", "Product_Name": "Industrial plastic sheets"},
            {"SKU_Code": "WHE001", "Product_Name": "Heavy duty wheels 200mm"},
            {"SKU_Code": "ELE001", "Product_Name": "Electronic components kit"},
        ]

    try:
        # Embed the query and every catalog entry, then rank by cosine similarity.
        query_vec = sbert_model.encode([product_name])
        catalog_vecs = sbert_model.encode([entry["Product_Name"] for entry in sku_database])
        similarities = util.cos_sim(query_vec, catalog_vecs)[0]

        top = similarities.argmax().item()
        best = sku_database[top]
        return similarities[top].item(), best["SKU_Code"], best["Product_Name"], similarities
    except Exception as e:
        print(f"Error in semantic similarity: {e}")
        return no_match
def predict_po_risk(product_name, quantity, delivery_date, filename, company_name=""):
    """Predict the risk level of a purchase order.

    Combines hand-crafted features (missing fields, semantic SKU match,
    filename risk, delivery urgency, description rarity) in a rule-based
    score that stands in for the trained XGBoost model.

    Returns:
        dict: display-ready result fields, or a single-key error dict when
        anything raises (this is the UI boundary, so errors are reported,
        not propagated).
    """
    try:
        # Completeness of the PO record.
        missing_score = missing_field_score_v2(product_name, quantity, delivery_date, filename, company_name)

        # Best semantic match against the SKU catalog. The full similarity
        # tensor was previously used to compute an "ambiguity gap" that was
        # never consumed -- dead code removed.
        cosine_similarity, matched_sku_code, matched_sku_name, _similarities = compute_semantic_similarity(product_name)

        # Document-type risk derived from the filename prefix.
        filename_encoding = get_filename_encoding(filename)

        # Urgency flag: delivery within 3 days (or unparseable date).
        delivery_lag = delivery_lag_flag(delivery_date)

        # Simple semantic signal (PCA would normally be applied here).
        semantic_signal = cosine_similarity - 0.5  # centered around 0

        # Token rarity (simplified -- the real model uses corpus statistics).
        # For an empty name len(words) == 0 and the formula already yields
        # 1.0, so the previous `if words else 1.0` branch was redundant.
        words = str(product_name).lower().split()
        description_rarity = 1.0 / (len(words) + 1)

        # Rule-based stand-in for the trained XGBoost model: each factor is
        # scaled so larger means riskier, then averaged.
        risk_factors = [
            missing_score * 3.0,              # weight missing fields heavily
            (1.0 - cosine_similarity) * 2.0,  # low similarity = higher risk
            filename_encoding / 4.0,          # normalize filename score
            delivery_lag * 1.5,               # urgent delivery increases risk
            description_rarity * 1.0,         # rare descriptions are riskier
        ]
        risk_score = float(np.mean(risk_factors))

        # Map the continuous score onto a label with a heuristic confidence.
        if risk_score > 0.7:
            predicted_risk = "High"
            confidence = min(0.95, 0.6 + risk_score * 0.35)
        elif risk_score > 0.4:
            predicted_risk = "Medium"
            confidence = 0.75
        else:
            predicted_risk = "Low"
            confidence = min(0.95, 0.85 - risk_score * 0.3)

        # Display-ready result dictionary (rendered by the gr.JSON output).
        return {
            "π― Risk Level": predicted_risk,
            "π Risk Score": f"{risk_score:.3f}",
            "π² Confidence": f"{confidence:.3f}",
            "β Missing Field Score": f"{missing_score:.3f}",
            "π Cosine Similarity": f"{cosine_similarity:.3f}",
            "π Filename Risk Score": f"{filename_encoding:.1f}",
            "β‘ Delivery Urgency": "Yes" if delivery_lag else "No",
            "π·οΈ Matched SKU Code": matched_sku_code or "No match",
            "π Matched SKU Name": matched_sku_name or "No match",
            "π Semantic Signal": f"{semantic_signal:.3f}",
            "π€ Description Rarity": f"{description_rarity:.3f}"
        }
    except Exception as e:
        return {"β Error": f"Prediction failed: {str(e)}"}
# Create Gradio interface
# Two-column layout: PO input form on the left, JSON risk report on the
# right, followed by cached examples and an "about" section.
# NOTE(review): indentation was reconstructed from context (extraction
# stripped it); quantity and delivery_date are assumed to share the inner
# Row -- confirm against the original layout.
with gr.Blocks(title="PO Risk Validator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# π Purchase Order Risk Validator")
    gr.Markdown("## AI-powered analysis to assess PO risk using semantic matching and XGBoost prediction")
    with gr.Row():
        # Left column: the PO input form.
        with gr.Column(scale=1):
            gr.Markdown("### π Enter PO Details")
            product_name = gr.Textbox(
                label="Product Name",
                placeholder="e.g., High-quality steel bolts M8x50",
                info="Detailed product description helps improve accuracy",
                lines=2
            )
            # Quantity and delivery date side by side.
            with gr.Row():
                quantity = gr.Number(
                    label="Quantity",
                    value=1,
                    minimum=0,
                    info="Order quantity"
                )
                delivery_date = gr.Textbox(
                    label="Delivery Date",
                    placeholder="2025-08-15",
                    info="Expected delivery date (YYYY-MM-DD)"
                )
            filename = gr.Textbox(
                label="Document Filename",
                placeholder="invoice_001.pdf",
                info="Original document filename"
            )
            company_name = gr.Textbox(
                label="Company Name (Optional)",
                placeholder="SteelCorp Ltd.",
                info="Supplier company name"
            )
            predict_btn = gr.Button("π Analyze PO Risk", variant="primary", size="lg")
        # Right column: analysis output and a legend for the result fields.
        with gr.Column(scale=1):
            gr.Markdown("### π Risk Assessment Results")
            output = gr.JSON(label="Analysis Results", show_label=False)
            gr.Markdown("### βΉοΈ Understanding the Results")
            gr.Markdown("""
            - **Risk Level**: Overall assessment (Low/Medium/High)
            - **Risk Score**: Numerical risk value (0-1, higher = riskier)
            - **Confidence**: Model confidence in prediction
            - **Missing Field Score**: Penalty for incomplete data
            - **Cosine Similarity**: Semantic match with SKU database
            - **Filename Risk Score**: Risk based on document type
            - **Delivery Urgency**: Whether delivery is within 3 days
            """)
    # Examples section
    gr.Markdown("### π― Try These Examples")
    examples = [
        ["High-quality steel bolts M8x50", 100, "2025-08-15", "order_ref_001.pdf", "SteelCorp Ltd"],
        ["", 0, "", "invoice_urgent.pdf", ""],  # High risk example
        ["Premium LED lights 12V", 50, "2025-09-01", "po_standard_123.pdf", "LightTech Inc"],
        ["Industrial grade components", 25, "2025-07-30", "txn_immediate.pdf", "QuickSupply Co"],
    ]
    # cache_examples=True precomputes these predictions at startup.
    gr.Examples(
        examples=examples,
        inputs=[product_name, quantity, delivery_date, filename, company_name],
        outputs=output,
        fn=predict_po_risk,
        cache_examples=True,
        label="Sample PO Data"
    )
    # Connect the button
    predict_btn.click(
        fn=predict_po_risk,
        inputs=[product_name, quantity, delivery_date, filename, company_name],
        outputs=output
    )
    gr.Markdown("---")
    gr.Markdown("### π About This Model")
    gr.Markdown("""
    This demo showcases a simplified version of the PO Risk Validator. The full production model includes:
    - Fine-tuned Sentence-BERT for semantic product matching
    - XGBoost classifier trained on historical PO data
    - Advanced feature engineering and PCA dimensionality reduction
    - Real-time SKU database integration
    """)
# Launch the app
if __name__ == "__main__":
    # share=True requests a public gradio.live tunnel for local runs; on
    # managed hosts (e.g. HF Spaces) it is ignored with a warning.
    demo.launch(share=True)