from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Optional, List, Dict
import pandas as pd
import numpy as np
from datetime import datetime
from huggingface_hub import InferenceClient
import requests
import os
from collections import Counter
import re

app = FastAPI(
    title="PO Risk Validator API",
    description="AI-powered Purchase Order risk assessment API using HuggingFace models",
    version="1.0.0",
)

# Wide-open CORS: the API is expected to be called from arbitrary front-ends.
# NOTE(review): allow_origins=["*"] with allow_credentials=True is rejected by
# browsers per the CORS spec — confirm whether credentialed requests are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize HuggingFace clients.
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")  # Set your HF token as environment variable

# Replace with your actual model IDs once uploaded.
SBERT_MODEL_ID = "your-username/po-validator-sbert"
XGBOOST_MODEL_ID = "your-username/po-validator-xgboost"

# Remote inference client used for sentence embeddings.
sbert_client = InferenceClient(model=SBERT_MODEL_ID, token=HF_TOKEN)


class POItem(BaseModel):
    """A single purchase-order line item submitted for risk assessment."""

    product_name: str = Field(..., description="Product description")
    quantity: float = Field(..., gt=0, description="Order quantity")
    delivery_date: str = Field(..., description="Expected delivery date (YYYY-MM-DD)")
    filename: str = Field(..., description="Source document filename")
    company_name: Optional[str] = Field(None, description="Supplier company name")


class POBatch(BaseModel):
    """A batch of PO items, optionally accompanied by a SKU reference database."""

    items: List[POItem]
    sku_database: Optional[List[Dict]] = Field(None, description="SKU reference database")


class RiskPrediction(BaseModel):
    """Risk assessment result for one PO item."""

    item_index: int
    product_name: str
    predicted_risk: str
    risk_score: float
    confidence: float
    matched_sku_code: Optional[str] = None
    matched_sku_name: Optional[str] = None
    cosine_similarity: Optional[float] = None
    missing_field_score: float
    semantic_signal: Optional[float] = None


class BatchPredictionResponse(BaseModel):
    """Predictions for a batch plus a count of items per risk level."""

    predictions: List[RiskPrediction]
    summary: Dict[str, int]


def missing_field_score_v2(item: POItem) -> float:
    """Score how incomplete/suspicious an item's fields are, normalized to [0, 1].

    Points accrued: empty product name (+2) or name shorter than 3 words (+1);
    non-positive quantity (+2, unreachable via the API because POItem enforces
    gt=0, kept for direct callers); missing, unparseable, or already-past
    delivery date (+1); blank filename (+0.5); blank company name (+0.5).

    NOTE(review): the divisor 8 exceeds the maximum attainable 6 points, so the
    result never reaches 1.0 — confirm the intended scale.
    """
    score = 0.0
    name = str(item.product_name).strip().lower()
    words = name.split()
    if not name:
        score += 2
    elif len(words) < 3:
        score += 1
    if item.quantity <= 0:
        score += 2
    if not item.delivery_date:
        score += 1
    else:
        try:
            delivery_dt = pd.to_datetime(item.delivery_date)
            # Naive datetimes on both sides; a tz-aware input string would make
            # the subtraction raise and fall into the except branch below.
            days_to_delivery = (delivery_dt - datetime.now()).days
            if days_to_delivery <= 0:
                score += 1
        except (ValueError, TypeError):
            # Unparseable date is penalized like a past-due date. (Narrowed
            # from a bare `except:` which also swallowed system exits.)
            score += 1
    if not str(item.filename).strip():
        score += 0.5
    if not str(item.company_name or "").strip():
        score += 0.5
    return score / 8


def get_semantic_similarity(product_text: str, sku_texts: List[str]) -> Dict:
    """Embed the product text and all SKU texts remotely and find the best match.

    Returns a dict with keys "similarity" (best cosine similarity),
    "matched_index" (index into sku_texts), and "similarities" (all scores).
    All failure paths return the same key set with neutral values, so callers
    can index the result unconditionally.
    """
    try:
        # One round trip: first row is the product embedding, the rest are SKUs.
        response = sbert_client.feature_extraction([product_text] + sku_texts)
        if len(response) < 2:
            # Previously this path returned a different key set ("matched_sku",
            # no "similarities"); keep the shape consistent with the other paths.
            return {"similarity": 0.0, "matched_index": 0, "similarities": []}

        product_embedding = np.array(response[0])
        sku_embeddings = np.array(response[1:])

        # Cosine similarity of the product against each SKU embedding.
        similarities = []
        for sku_emb in sku_embeddings:
            sim = np.dot(product_embedding, sku_emb) / (
                np.linalg.norm(product_embedding) * np.linalg.norm(sku_emb)
            )
            similarities.append(sim)

        best_match_idx = int(np.argmax(similarities))
        return {
            "similarity": float(similarities[best_match_idx]),
            "matched_index": best_match_idx,
            "similarities": similarities,
        }
    except Exception as e:
        # Best-effort: inference-API failures degrade to "no match" rather
        # than failing the whole prediction.
        print(f"Error in semantic similarity: {e}")
        return {"similarity": 0.0, "matched_index": 0, "similarities": []}


def predict_risk_hf(item: POItem, sku_database: Optional[List[Dict]] = None) -> RiskPrediction:
    """Assess one PO item and return a RiskPrediction.

    Combines three signals: the missing-field score, (when a SKU database is
    supplied) semantic similarity to the closest SKU, and a filename-prefix
    heuristic. Currently rule-based; intended to be replaced by a call to the
    XGBoost model. The caller is responsible for setting item_index.
    """
    missing_score = missing_field_score_v2(item)

    # Semantic matching against the SKU database, if one was provided.
    matched_sku_code = None
    matched_sku_name = None
    cosine_similarity = 0.0
    semantic_signal = 0.0
    if sku_database:
        sku_texts = [sku.get("Product_Name", "") for sku in sku_database]
        sku_codes = [sku.get("SKU_Code", "") for sku in sku_database]
        sem_result = get_semantic_similarity(item.product_name, sku_texts)
        cosine_similarity = sem_result["similarity"]
        matched_idx = sem_result["matched_index"]
        if matched_idx < len(sku_database):
            matched_sku_code = sku_codes[matched_idx]
            matched_sku_name = sku_texts[matched_idx]
        # Simple semantic signal: similarity centered around 0.
        semantic_signal = cosine_similarity - 0.5

    # Filename-prefix heuristic: certain prefixes historically correlate with
    # risk. The high-risk tuple is tested first, so 'manzillglobe' is not
    # shadowed by the low-risk 'manzill' prefix.
    filename_score = 0.0
    if item.filename:
        filename_str = str(item.filename).lower()
        if filename_str.startswith(('invoice', 'txn', 'mgt', 'manzillglobe', 'daljit')):
            filename_score = 3.0  # High risk
        elif filename_str.startswith(('order', 'po', 'ref', 'manzill')):
            filename_score = 1.0  # Low risk
        else:
            filename_score = 2.0  # Medium risk

    # Simple rule-based risk aggregation (replace with actual XGBoost model call).
    risk_factors = [
        missing_score * 2,  # weight missing fields heavily
        (1.0 - cosine_similarity) if sku_database else 0.5,  # low similarity = higher risk
        filename_score / 4.0,  # normalize filename score
    ]
    # Coerce np.float64 to a plain float before handing it to Pydantic.
    risk_score = float(np.mean(risk_factors))

    # Map the aggregate score onto a risk level with an ad-hoc confidence.
    if risk_score > 0.6:
        predicted_risk = "High"
        confidence = min(0.95, 0.6 + risk_score * 0.4)
    elif risk_score > 0.3:
        predicted_risk = "Medium"
        confidence = 0.7
    else:
        predicted_risk = "Low"
        confidence = min(0.95, 0.8 - risk_score * 0.3)

    return RiskPrediction(
        item_index=0,  # will be set by caller
        product_name=item.product_name,
        predicted_risk=predicted_risk,
        risk_score=round(risk_score, 3),
        confidence=round(confidence, 3),
        matched_sku_code=matched_sku_code,
        matched_sku_name=matched_sku_name,
        # Report similarity fields whenever semantic matching actually ran.
        # (The previous truthiness test turned a legitimate 0.0 into None.)
        cosine_similarity=round(cosine_similarity, 3) if sku_database else None,
        missing_field_score=round(missing_score, 3),
        semantic_signal=round(semantic_signal, 3) if sku_database else None,
    )


@app.get("/")
async def root():
    """Service banner listing the available endpoints."""
    return {
        "message": "PO Risk Validator API",
        "version": "1.0.0",
        "endpoints": [
            "/predict - Single PO prediction",
            "/predict/batch - Batch PO predictions",
            "/health - Health check"
        ]
    }
@app.post("/predict", response_model=RiskPrediction)
async def predict_single(item: POItem, sku_database: Optional[List[Dict]] = None):
    """Predict risk for a single PO item.

    Raises HTTP 500 (with the original exception chained for server-side
    tracebacks) if the prediction pipeline fails.
    """
    try:
        prediction = predict_risk_hf(item, sku_database)
        return prediction
    except Exception as e:
        # `from e` preserves the causal traceback in server logs (PEP 3134).
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}") from e


@app.post("/predict/batch", response_model=BatchPredictionResponse)
async def predict_batch(batch: POBatch):
    """Predict risk for multiple PO items and summarize counts per risk level."""
    try:
        predictions = []
        for idx, item in enumerate(batch.items):
            prediction = predict_risk_hf(item, batch.sku_database)
            prediction.item_index = idx
            predictions.append(prediction)

        # Summary always exposes all three levels; `.get` accumulation keeps
        # this from raising KeyError if the predictor ever emits a new label.
        risk_counts = {"High": 0, "Medium": 0, "Low": 0}
        for pred in predictions:
            risk_counts[pred.predicted_risk] = risk_counts.get(pred.predicted_risk, 0) + 1

        return BatchPredictionResponse(
            predictions=predictions,
            summary=risk_counts
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch prediction error: {str(e)}") from e


@app.get("/health")
async def health_check():
    """Health check endpoint reporting the configured model IDs."""
    return {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "models": {
            "sbert": SBERT_MODEL_ID,
            "xgboost": XGBOOST_MODEL_ID
        }
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)