# PO Risk Validator API — FastAPI service for a HuggingFace Space.
# (The "Spaces: Sleeping" lines were UI residue from the Space status page, not code.)
# Standard library.
import os
import re
from collections import Counter
from datetime import datetime
from typing import Dict, List, Optional

# Third party.
import numpy as np
import pandas as pd
import requests
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import InferenceClient
from pydantic import BaseModel, Field
# --- Application setup -------------------------------------------------------
app = FastAPI(
    title="PO Risk Validator API",
    description="AI-powered Purchase Order risk assessment API using HuggingFace models",
    version="1.0.0",
)

# Permissive CORS so browser front-ends on any origin can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- HuggingFace configuration ----------------------------------------------
# Set your HF token as environment variable
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

# Replace with your actual model IDs once uploaded
SBERT_MODEL_ID = "your-username/po-validator-sbert"
XGBOOST_MODEL_ID = "your-username/po-validator-xgboost"

# Remote inference client for the sentence-embedding (SBERT) model.
sbert_client = InferenceClient(model=SBERT_MODEL_ID, token=HF_TOKEN)
class POItem(BaseModel):
    """One purchase-order line item submitted for risk scoring."""

    # Free-text product description; used for semantic SKU matching.
    product_name: str = Field(..., description="Product description")
    # Ordered quantity; pydantic enforces strictly positive (gt=0).
    quantity: float = Field(..., gt=0, description="Order quantity")
    # Delivery date kept as a string; parsed downstream with pd.to_datetime.
    delivery_date: str = Field(..., description="Expected delivery date (YYYY-MM-DD)")
    # Name of the source document the item came from; feeds the filename heuristic.
    filename: str = Field(..., description="Source document filename")
    # Optional supplier name; a blank value raises the missing-field score.
    company_name: Optional[str] = Field(None, description="Supplier company name")
class POBatch(BaseModel):
    """Request payload for batch prediction: items plus an optional SKU catalog."""

    # Line items to score, in request order (indices are echoed back).
    items: List[POItem]
    # Optional reference catalog; each dict is expected to carry
    # "Product_Name" and "SKU_Code" keys (see predict_risk_hf).
    sku_database: Optional[List[Dict]] = Field(None, description="SKU reference database")
class RiskPrediction(BaseModel):
    """Risk assessment result for a single PO item."""

    # Position of the item in the submitted batch (0 for single predictions).
    item_index: int
    # Echo of the input product description.
    product_name: str
    # "High" / "Medium" / "Low".
    predicted_risk: str
    # Aggregate risk score in roughly [0, 1]; higher means riskier.
    risk_score: float
    # Heuristic confidence in the predicted label.
    confidence: float
    # Best-matching SKU from the reference database, when one was provided.
    matched_sku_code: Optional[str] = None
    matched_sku_name: Optional[str] = None
    # Cosine similarity to the matched SKU; None when no SKU matching ran.
    cosine_similarity: Optional[float] = None
    # Normalized missing/weak-field penalty from missing_field_score_v2.
    missing_field_score: float
    # Similarity centered around 0 (cosine - 0.5); None when no SKU matching ran.
    semantic_signal: Optional[float] = None
class BatchPredictionResponse(BaseModel):
    """Response payload for batch prediction."""

    # Per-item predictions, in the same order as the request's items.
    predictions: List[RiskPrediction]
    # Count of predictions per risk level, e.g. {"High": 1, "Medium": 0, "Low": 2}.
    summary: Dict[str, int]
def missing_field_score_v2(item: "POItem") -> float:
    """Score how incomplete or suspicious a PO item's fields are.

    Penalties accumulate per weak field:
      - empty product name (+2), or a very short one (< 3 words, +1)
      - non-positive quantity (+2)
      - missing delivery date (+1), or unparseable / non-future date (+1)
      - blank filename (+0.5) and blank company name (+0.5)

    Returns:
        The total penalty divided by 8, so the result lies in [0, 0.75].
    """
    score = 0.0

    name = str(item.product_name).strip().lower()
    if not name:
        score += 2
    elif len(name.split()) < 3:
        score += 1

    if item.quantity <= 0:
        score += 2

    if not item.delivery_date:
        score += 1
    else:
        try:
            delivery_dt = pd.to_datetime(item.delivery_date)
            # Deliveries due today or in the past are a data-quality red flag.
            if (delivery_dt - datetime.now()).days <= 0:
                score += 1
        except (ValueError, TypeError):
            # Narrowed from a bare `except:`: unparseable dates (ValueError,
            # incl. pandas parser errors) and naive/aware subtraction
            # mismatches (TypeError) count the same as a past-due date.
            score += 1

    if not str(item.filename).strip():
        score += 0.5
    if not str(item.company_name or "").strip():
        score += 0.5

    return score / 8
def get_semantic_similarity(product_text: str, sku_texts: List[str]) -> Dict:
    """Find the SKU text most semantically similar to a product description.

    Embeds the product text and every SKU text in a single HuggingFace
    inference call, then selects the SKU with the highest cosine similarity.

    Returns:
        Dict with keys:
          - "similarity": best cosine similarity (0.0 on failure / no SKUs)
          - "matched_index": index into sku_texts of the best match
          - "similarities": per-SKU cosine similarities as plain floats
        (The original code returned a different key set, with "matched_sku",
        from the too-few-embeddings branch; the keys are now consistent.)
    """
    try:
        # One API round-trip: row 0 is the product, the rest are SKUs.
        response = sbert_client.feature_extraction([product_text] + sku_texts)
        if len(response) < 2:
            return {"similarity": 0.0, "matched_index": 0, "similarities": []}

        product_embedding = np.asarray(response[0], dtype=float)
        sku_embeddings = np.asarray(response[1:], dtype=float)

        # Vectorized cosine similarity, replacing the per-SKU Python loop.
        denom = np.linalg.norm(product_embedding) * np.linalg.norm(sku_embeddings, axis=1)
        denom = np.where(denom == 0, 1.0, denom)  # zero-norm guard: similarity -> 0
        similarities = (sku_embeddings @ product_embedding) / denom

        best_match_idx = int(np.argmax(similarities))
        return {
            "similarity": float(similarities[best_match_idx]),
            "matched_index": best_match_idx,
            "similarities": similarities.tolist(),
        }
    except Exception as e:
        # Network/auth/model failures degrade gracefully to "no match".
        print(f"Error in semantic similarity: {e}")
        return {"similarity": 0.0, "matched_index": 0, "similarities": []}
def predict_risk_hf(item: POItem, sku_database: Optional[List[Dict]] = None) -> RiskPrediction:
    """Predict risk for one PO item from heuristic features.

    Combines three signals into a mean risk score:
      1. missing/weak-field penalty (weighted x2),
      2. dissimilarity to the best SKU match (or a neutral 0.5 without a DB),
      3. a filename-prefix heuristic, normalized to [0, 0.75].

    Note: this is a rule-based stand-in — replace with the actual XGBoost
    model call once the model is served.

    Args:
        item: The PO line item to score.
        sku_database: Optional catalog of dicts with "Product_Name" and
            "SKU_Code" keys used for semantic matching.

    Returns:
        A RiskPrediction with item_index=0 (the batch caller overwrites it).
    """
    missing_score = missing_field_score_v2(item)

    # Semantic matching runs only when a SKU database is provided.
    matched_sku_code = None
    matched_sku_name = None
    cosine_similarity = 0.0
    semantic_signal = 0.0
    if sku_database:
        sku_texts = [sku.get("Product_Name", "") for sku in sku_database]
        sku_codes = [sku.get("SKU_Code", "") for sku in sku_database]
        sem_result = get_semantic_similarity(item.product_name, sku_texts)
        cosine_similarity = sem_result["similarity"]
        matched_idx = sem_result["matched_index"]
        if matched_idx < len(sku_database):
            matched_sku_code = sku_codes[matched_idx]
            matched_sku_name = sku_texts[matched_idx]
        # Center similarity around 0 so positive means "probably a known SKU".
        semantic_signal = cosine_similarity - 0.5

    # Filename prefix heuristic (tuples are checked in order; e.g.
    # "manzillglobe*" hits the high-risk tuple before "manzill*" is tested).
    filename_score = 0.0
    if item.filename:
        filename_str = str(item.filename).lower()
        if filename_str.startswith(('invoice', 'txn', 'mgt', 'manzillglobe', 'daljit')):
            filename_score = 3.0  # High risk
        elif filename_str.startswith(('order', 'po', 'ref', 'manzill')):
            filename_score = 1.0  # Low risk
        else:
            filename_score = 2.0  # Medium risk

    # Simple rule-based risk prediction (replace with actual XGBoost model call)
    risk_factors = [
        missing_score * 2,  # weight missing fields heavily
        (1.0 - cosine_similarity) if sku_database else 0.5,  # low similarity = higher risk
        filename_score / 4.0,  # normalize filename score
    ]
    risk_score = float(np.mean(risk_factors))  # plain float, not np.float64

    # Map the score onto a label with a heuristic confidence.
    if risk_score > 0.6:
        predicted_risk = "High"
        confidence = min(0.95, 0.6 + risk_score * 0.4)
    elif risk_score > 0.3:
        predicted_risk = "Medium"
        confidence = 0.7
    else:
        predicted_risk = "Low"
        confidence = min(0.95, 0.8 - risk_score * 0.3)

    # Report similarity metrics whenever matching actually ran. The original
    # used `if cosine_similarity`, which silently turned a legitimate 0.0
    # similarity (and a 0.0 semantic signal) into None.
    did_match = bool(sku_database)
    return RiskPrediction(
        item_index=0,  # will be set by the batch caller
        product_name=item.product_name,
        predicted_risk=predicted_risk,
        risk_score=round(risk_score, 3),
        confidence=round(confidence, 3),
        matched_sku_code=matched_sku_code,
        matched_sku_name=matched_sku_name,
        cosine_similarity=round(cosine_similarity, 3) if did_match else None,
        missing_field_score=round(missing_score, 3),
        semantic_signal=round(semantic_signal, 3) if did_match else None,
    )
@app.get("/")  # fix: no route decorator existed, so "/" was never registered
async def root():
    """Service banner listing the available endpoints."""
    return {
        "message": "PO Risk Validator API",
        "version": "1.0.0",
        "endpoints": [
            "/predict - Single PO prediction",
            "/predict/batch - Batch PO predictions",
            "/health - Health check"
        ]
    }
@app.post("/predict")  # fix: no route decorator existed, so the endpoint was never registered
async def predict_single(item: POItem, sku_database: Optional[List[Dict]] = None):
    """Predict risk for a single PO item.

    Args:
        item: The PO line item to score.
        sku_database: Optional SKU catalog for semantic matching.

    Raises:
        HTTPException: 500 wrapping any error raised during prediction.
    """
    try:
        return predict_risk_hf(item, sku_database)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
@app.post("/predict/batch")  # fix: no route decorator existed, so the endpoint was never registered
async def predict_batch(batch: POBatch):
    """Predict risk for every item in a batch and summarize by risk level.

    Args:
        batch: Items to score plus an optional shared SKU database.

    Returns:
        BatchPredictionResponse with per-item predictions (item_index set to
        each item's position) and a High/Medium/Low count summary.

    Raises:
        HTTPException: 500 wrapping any error raised during prediction.
    """
    try:
        predictions = []
        for idx, item in enumerate(batch.items):
            prediction = predict_risk_hf(item, batch.sku_database)
            prediction.item_index = idx
            predictions.append(prediction)

        # Summary always carries all three levels, even when a count is zero.
        risk_counts = {"High": 0, "Medium": 0, "Low": 0}
        for pred in predictions:
            risk_counts[pred.predicted_risk] += 1

        return BatchPredictionResponse(
            predictions=predictions,
            summary=risk_counts
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch prediction error: {str(e)}")
@app.get("/health")  # fix: no route decorator existed, so the endpoint was never registered
async def health_check():
    """Health check endpoint reporting the configured model IDs."""
    return {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "models": {
            "sbert": SBERT_MODEL_ID,
            "xgboost": XGBOOST_MODEL_ID
        }
    }
| if __name__ == "__main__": | |
| import uvicorn | |
| uvicorn.run(app, host="0.0.0.0", port=8000) | |