# fastapi_hf.py — PO Risk Validator API application file (ANSPOValidator).
# Author: Manveer (commit 757cb88, "Add application file")
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Optional, List, Dict
import pandas as pd
import numpy as np
from datetime import datetime
from huggingface_hub import InferenceClient
import requests
import os
from collections import Counter
import re
# --- Application setup (module-level side effects) ---

# FastAPI application instance; this metadata appears in the auto-generated docs.
app = FastAPI(
    title="PO Risk Validator API",
    description="AI-powered Purchase Order risk assessment API using HuggingFace models",
    version="1.0.0"
)

# Allow any origin/method/header so browser front-ends can call the API.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is acceptable for the deployment environment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize HuggingFace clients
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")  # Set your HF token as environment variable

# Replace with your actual model IDs once uploaded
SBERT_MODEL_ID = "your-username/po-validator-sbert"
XGBOOST_MODEL_ID = "your-username/po-validator-xgboost"  # only reported by /health; no client is built for it here

# Initialize inference clients
sbert_client = InferenceClient(model=SBERT_MODEL_ID, token=HF_TOKEN)
class POItem(BaseModel):
    """A single purchase-order line item submitted for risk scoring."""
    product_name: str = Field(..., description="Product description")
    quantity: float = Field(..., gt=0, description="Order quantity")
    delivery_date: str = Field(..., description="Expected delivery date (YYYY-MM-DD)")
    filename: str = Field(..., description="Source document filename")
    company_name: Optional[str] = Field(None, description="Supplier company name")
class POBatch(BaseModel):
    """A batch of PO items plus an optional SKU reference database."""
    items: List[POItem]
    # Each SKU dict is read with "Product_Name" and "SKU_Code" keys in
    # predict_risk_hf — TODO confirm this schema with callers.
    sku_database: Optional[List[Dict]] = Field(None, description="SKU reference database")
class RiskPrediction(BaseModel):
    """Risk assessment result for one PO item."""
    item_index: int  # position of the item within the submitted batch (0 for single predictions)
    product_name: str
    predicted_risk: str  # one of "High" / "Medium" / "Low"
    risk_score: float  # mean of the weighted risk factors, rounded to 3 dp
    confidence: float
    matched_sku_code: Optional[str] = None  # best SKU match (only when a SKU DB was given)
    matched_sku_name: Optional[str] = None
    cosine_similarity: Optional[float] = None  # SBERT similarity to the matched SKU
    missing_field_score: float  # normalized field-completeness score
    semantic_signal: Optional[float] = None  # cosine similarity centered at 0.5
class BatchPredictionResponse(BaseModel):
    """Response envelope for /predict/batch."""
    predictions: List[RiskPrediction]
    summary: Dict[str, int]  # counts per risk level: keys "High", "Medium", "Low"
def missing_field_score_v2(item: "POItem") -> float:
    """Score how incomplete or suspect an item's fields are.

    Heuristic point system, normalized by dividing by 8:
      +2  empty product name, or +1 if it has fewer than 3 words
      +2  non-positive quantity
      +1  missing, unparseable, or non-future delivery date
      +0.5 blank filename
      +0.5 blank company name
    (Maximum raw score is 6, so the result lies in [0, 0.75] — the /8
    normalizer never reaches 1.0; kept for backward compatibility.)

    Args:
        item: The PO line item to score (only attribute reads; any object
            exposing the same attributes works).

    Returns:
        Normalized missing-field score as a float.
    """
    score = 0.0
    name = str(item.product_name).strip().lower()
    if not name:
        score += 2
    elif len(name.split()) < 3:
        score += 1
    if item.quantity <= 0:
        score += 2
    if not item.delivery_date:
        score += 1
    else:
        try:
            delivery_dt = pd.to_datetime(item.delivery_date)
            # Past (or same-day) delivery dates are suspicious for a new PO.
            if (delivery_dt - datetime.now()).days <= 0:
                score += 1
        except Exception:  # was a bare except: don't swallow SystemExit/KeyboardInterrupt
            score += 1
    if not str(item.filename).strip():
        score += 0.5
    if not str(item.company_name or "").strip():
        score += 0.5
    return score / 8
def get_semantic_similarity(product_text: str, sku_texts: List[str]) -> Dict:
    """Embed product_text and sku_texts via the HF SBERT endpoint and find
    the closest SKU by cosine similarity.

    Args:
        product_text: Free-text product description from the PO.
        sku_texts: Candidate SKU names to match against.

    Returns:
        Dict with keys:
          similarity: best cosine similarity (0.0 on failure / no candidates)
          matched_index: index into sku_texts of the best match (0 on failure)
          similarities: full list of per-SKU similarities ([] on failure)
    """
    try:
        # One batched call: first row is the product embedding, rest are SKUs.
        response = sbert_client.feature_extraction([product_text] + sku_texts)
        if len(response) < 2:
            # No SKU embeddings came back. BUG FIX: this branch previously
            # returned a stray "matched_sku" key instead of "similarities",
            # making the result shape inconsistent with the other paths.
            return {"similarity": 0.0, "matched_index": 0, "similarities": []}

        # Assumes the endpoint returns one fixed-length vector per input text
        # (rectangular matrix) — TODO confirm for the deployed model.
        embeddings = np.asarray(response, dtype=float)
        product_embedding = embeddings[0]
        sku_embeddings = embeddings[1:]

        # Vectorized cosine similarity; guard against zero-norm vectors to
        # avoid a divide-by-zero warning / NaN.
        norms = np.linalg.norm(sku_embeddings, axis=1) * np.linalg.norm(product_embedding)
        norms = np.where(norms == 0, 1.0, norms)
        similarities = (sku_embeddings @ product_embedding) / norms

        best_match_idx = int(np.argmax(similarities))
        return {
            "similarity": float(similarities[best_match_idx]),
            "matched_index": best_match_idx,
            "similarities": similarities.tolist(),
        }
    except Exception as e:
        print(f"Error in semantic similarity: {e}")
        return {"similarity": 0.0, "matched_index": 0, "similarities": []}
def predict_risk_hf(item: POItem, sku_database: Optional[List[Dict]] = None) -> RiskPrediction:
    """Predict the risk level of a single PO item.

    Combines three heuristic signals:
      1. Missing-field score (weighted x2),
      2. Semantic dissimilarity to the closest SKU (when a DB is given),
      3. A filename-prefix risk heuristic.

    Args:
        item: The PO line item to assess.
        sku_database: Optional list of SKU dicts with "Product_Name" and
            "SKU_Code" keys used for semantic matching.

    Returns:
        RiskPrediction with item_index=0 (the batch caller overwrites it).
    """
    missing_score = missing_field_score_v2(item)

    # Semantic matching against the SKU database, when provided.
    matched_sku_code = None
    matched_sku_name = None
    cosine_similarity = 0.0
    semantic_signal = 0.0
    if sku_database:
        sku_texts = [sku.get("Product_Name", "") for sku in sku_database]
        sku_codes = [sku.get("SKU_Code", "") for sku in sku_database]
        sem_result = get_semantic_similarity(item.product_name, sku_texts)
        cosine_similarity = sem_result["similarity"]
        matched_idx = sem_result["matched_index"]
        if matched_idx < len(sku_database):
            matched_sku_code = sku_codes[matched_idx]
            matched_sku_name = sku_texts[matched_idx]
        # Center the similarity around 0 so values below 0.5 read as negative.
        semantic_signal = cosine_similarity - 0.5

    # Filename prefix heuristic (hard-coded, vendor-specific rules).
    filename_score = 0.0
    if item.filename:
        filename_str = str(item.filename).lower()
        # Order matters: 'manzillglobe' must be tested before 'manzill'.
        if filename_str.startswith(('invoice', 'txn', 'mgt', 'manzillglobe', 'daljit')):
            filename_score = 3.0  # High risk
        elif filename_str.startswith(('order', 'po', 'ref', 'manzill')):
            filename_score = 1.0  # Low risk
        else:
            filename_score = 2.0  # Medium risk

    # Simple rule-based risk prediction (replace with actual XGBoost model call).
    risk_factors = [
        missing_score * 2,  # Weight missing fields heavily
        (1.0 - cosine_similarity) if sku_database else 0.5,  # Low similarity = higher risk
        filename_score / 4.0,  # Normalize filename score to [0, 0.75]
    ]
    risk_score = float(np.mean(risk_factors))

    # Map the aggregate score onto discrete risk levels with a confidence.
    if risk_score > 0.6:
        predicted_risk = "High"
        confidence = min(0.95, 0.6 + risk_score * 0.4)
    elif risk_score > 0.3:
        predicted_risk = "Medium"
        confidence = 0.7
    else:
        predicted_risk = "Low"
        confidence = min(0.95, 0.8 - risk_score * 0.3)

    return RiskPrediction(
        item_index=0,  # Will be set by caller
        product_name=item.product_name,
        predicted_risk=predicted_risk,
        risk_score=round(risk_score, 3),
        confidence=round(confidence, 3),
        matched_sku_code=matched_sku_code,
        matched_sku_name=matched_sku_name,
        # BUG FIX: the old truthiness check (`if cosine_similarity`) reported
        # None for a legitimate similarity of exactly 0.0; gate on whether a
        # SKU database was supplied instead.
        cosine_similarity=round(cosine_similarity, 3) if sku_database else None,
        missing_field_score=round(missing_score, 3),
        # Same fix: a semantic signal of 0.0 (similarity == 0.5) is valid data.
        semantic_signal=round(semantic_signal, 3) if sku_database else None,
    )
@app.get("/")
async def root():
    """Landing endpoint: report the API name, version, and route overview."""
    endpoint_overview = [
        "/predict - Single PO prediction",
        "/predict/batch - Batch PO predictions",
        "/health - Health check",
    ]
    return {
        "message": "PO Risk Validator API",
        "version": "1.0.0",
        "endpoints": endpoint_overview,
    }
@app.post("/predict", response_model=RiskPrediction)
async def predict_single(item: POItem, sku_database: Optional[List[Dict]] = None):
    """Score a single PO item; any internal failure becomes an HTTP 500."""
    try:
        return predict_risk_hf(item, sku_database)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
@app.post("/predict/batch", response_model=BatchPredictionResponse)
async def predict_batch(batch: POBatch):
    """Score every item in the batch and summarize counts per risk level."""
    try:
        # Score all items, then stamp each prediction with its batch position.
        scored = [predict_risk_hf(po_item, batch.sku_database) for po_item in batch.items]
        for position, result in enumerate(scored):
            result.item_index = position

        # Tally always reports all three levels, even when a count is zero.
        tally = {"High": 0, "Medium": 0, "Low": 0}
        for result in scored:
            tally[result.predicted_risk] += 1

        return BatchPredictionResponse(predictions=scored, summary=tally)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch prediction error: {str(e)}")
@app.get("/health")
async def health_check():
    """Liveness probe: report status, current time, and configured model ids."""
    model_info = {
        "sbert": SBERT_MODEL_ID,
        "xgboost": XGBOOST_MODEL_ID,
    }
    return {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "models": model_info,
    }
# Local entry point: serve the API with uvicorn when executed directly
# (in production the ASGI server typically imports `app` instead).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)