# SBERT_NLP / app.py — FastAPI OCR + SBERT product-matching service
import os
import json
import torch
import io
import re
import numpy as np
import easyocr
import cv2
from fastapi import FastAPI, UploadFile, File, Form
from sentence_transformers import SentenceTransformer, util
from rapidfuzz import fuzz, process
from supabase import create_client, Client
from PIL import Image
# FastAPI application instance; the route decorators below attach to it.
app = FastAPI()
@app.get("/")
def home():
    """Health-check endpoint: confirms the service is reachable."""
    return {"status": "online", "service": "OCR-NLP-Engine"}
# Supabase credentials injected via environment variables (e.g. Spaces secrets).
SUPABASE_URL = os.environ.get("SUPABASE_URL")
SUPABASE_KEY = os.environ.get("SUPABASE_SERVICE_KEY")

# Lazily initialised singletons — populated once in the startup event.
supabase: Client = None
model = None
reader = None

# In-memory master-product cache; the three structures are index-aligned
# (position i holds one product's name, SKU, and embedding row).
MASTER_PRODUCTS = []
MASTER_SKUS = []
MASTER_EMBEDDINGS = None
@app.on_event("startup")
async def startup_event():
    """Initialise global singletons once at boot.

    Connects to Supabase (best-effort: failure leaves ``supabase`` as None
    and the service runs without DB-backed matching), loads the SBERT
    sentence model and the EasyOCR reader (Indonesian + English, CPU-only),
    then primes the in-memory product cache.
    """
    global supabase, model, reader
    if SUPABASE_URL and SUPABASE_KEY:
        try:
            supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
        except Exception:
            # Deliberate best-effort: endpoints still work without the DB.
            pass
    model = SentenceTransformer('all-MiniLM-L6-v2')
    reader = easyocr.Reader(['id', 'en'], gpu=False)
    await refresh_data_internal()
async def refresh_data_internal():
    """Reload the master-product cache (names, SKUs, embeddings) from Supabase.

    Rows lacking a name or an embedding are skipped so the three module-level
    caches stay index-aligned — ``run_match`` indexes all three with the same
    position. The swap is atomic at the end: the embedding tensor is built
    first, so a parse/typing failure leaves the previous cache fully intact
    (the original code replaced MASTER_PRODUCTS before building the tensor,
    which could leave names and embeddings misaligned on failure).
    """
    global MASTER_PRODUCTS, MASTER_EMBEDDINGS, MASTER_SKUS
    if not supabase:
        return
    try:
        response = supabase.table("master_products").select("id, name, sku, embedding").execute()
        data = response.data
        if not data:
            return
        names, skus, embs = [], [], []
        for item in data:
            name = item.get('name')
            emb = item.get('embedding')
            if not name or emb is None:
                # Incomplete row: including it would desynchronise the
                # name/SKU lists from the embedding tensor.
                continue
            if isinstance(emb, str):
                # Embedding may arrive serialised as "{0.1,0.2,...}" —
                # normalise the braces to JSON array syntax before parsing.
                emb = json.loads(emb.replace('{', '[').replace('}', ']'))
            names.append(name)
            skus.append(item.get('sku', 'N/A'))
            embs.append(emb)
        if not names:
            return
        # Build the tensor before touching the caches, then swap all three
        # together so readers never observe a partial update.
        embeddings = torch.tensor(embs, dtype=torch.float)
        MASTER_PRODUCTS, MASTER_SKUS = names, skus
        MASTER_EMBEDDINGS = embeddings
    except Exception:
        # Best-effort refresh: keep serving the previous cache on any error.
        pass
def improve_image(img_np):
    """Preprocess an RGB image array for OCR.

    Pipeline: grayscale -> 2x cubic upscale -> non-local-means denoise ->
    sharpen -> Gaussian adaptive threshold. Returns a binary (0/255) image.
    """
    gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
    upscaled = cv2.resize(gray, None, fx=2, fy=2, interpolation=cv2.INTER_CUBIC)
    smooth = cv2.fastNlMeansDenoising(upscaled, h=10)
    # Classic 3x3 sharpening kernel (centre 9, neighbours -1).
    sharpen_kernel = np.array([[-1, -1, -1],
                               [-1, 9, -1],
                               [-1, -1, -1]])
    crisp = cv2.filter2D(smooth, -1, sharpen_kernel)
    binary = cv2.adaptiveThreshold(
        crisp, 255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY,
        31, 2,
    )
    return binary
def calculate_usage_from_text(text):
    """Extract data-quota figures from OCR text and estimate usage in GB.

    Finds every "<number> GB" / "<number> MB" token (commas are treated as
    decimal points). MB values are converted to GB (/1024) — the original
    code matched "mb" but counted the raw number as gigabytes, so
    "512 MB" was read as 512 GB. Only plausible values (0.01–2000 GB)
    are kept.

    Returns:
        (used_gb, remaining_gb, method): method is "Max-Min Calc" when two
        or more figures are found (largest assumed to be the total quota,
        smallest the remaining balance), "Single Value" for exactly one,
        "No Data" otherwise.
    """
    clean_text = text.lower().replace(',', '.')
    pattern = r'(\d+(?:\.\d+)?)\s*(gb|mb)'
    values = []
    for number, unit in re.findall(pattern, clean_text):
        gb = float(number)
        if unit == 'mb':
            gb /= 1024.0  # normalise megabytes to gigabytes
        if 0.01 < gb < 2000:
            values.append(gb)
    if len(values) >= 2:
        return round(max(values) - min(values), 2), round(min(values), 2), "Max-Min Calc"
    if len(values) == 1:
        return 0.0, round(values[0], 2), "Single Value"
    return 0.0, 0.0, "No Data"
def run_match(text):
    """Match a line of OCR text against the cached master products.

    Blends SBERT cosine similarity with RapidFuzz token_set_ratio: when the
    semantic score is strong (> 0.85) it dominates (80/20), otherwise the
    fuzzy score leads (60/40).

    Returns:
        dict with master_name, sku and confidence (0-100), or None when the
        cache is empty or the blended score is below 0.30.
    """
    if not MASTER_PRODUCTS or MASTER_EMBEDDINGS is None:
        return None
    input_emb = model.encode(text, convert_to_tensor=True)
    scores = util.cos_sim(input_emb, MASTER_EMBEDDINGS)  # shape (1, N)
    s_idx = torch.argmax(scores).item()
    s_score = scores[0][s_idx].item()
    # extractOne returns (choice, score, index); keep the index so duplicate
    # product names resolve to the correct SKU. The original looked the name
    # up with list.index(), which always hits the first duplicate.
    f_name, f_raw, f_idx = process.extractOne(text, MASTER_PRODUCTS, scorer=fuzz.token_set_ratio)
    f_score = f_raw / 100.0
    if s_score > 0.85:
        final_score = (s_score * 0.8) + (f_score * 0.2)
        best_idx = s_idx
    else:
        final_score = (f_score * 0.6) + (s_score * 0.4)
        best_idx = f_idx
    if final_score >= 0.30:
        return {
            "master_name": MASTER_PRODUCTS[best_idx],
            "sku": MASTER_SKUS[best_idx],
            "confidence": round(final_score * 100, 2)
        }
    return None
@app.post("/process-invoice")
async def process_invoice(file: UploadFile = File(...)):
    """OCR an uploaded invoice image and match each text segment to a master product."""
    try:
        payload = await file.read()
        image = Image.open(io.BytesIO(payload)).convert("RGB")
        enhanced = improve_image(np.array(image))
        ocr_lines = reader.readtext(enhanced, detail=0)
        full_text = " ".join(ocr_lines)
        matches = []
        # Split the flattened text back into candidate product segments.
        for segment in re.split(r'\n|,|\s{2,}', full_text):
            candidate = segment.strip()
            if len(candidate) <= 3:
                continue  # too short to be a product name
            hit = run_match(candidate)
            if hit is not None:
                hit["original_text"] = candidate
                matches.append(hit)
        return {"type": "invoice", "matches": matches}
    except Exception as e:
        return {"error": str(e)}
@app.post("/process-quota")
async def process_quota(file: UploadFile = File(...)):
    """OCR an uploaded quota screenshot and estimate data usage in GB."""
    try:
        payload = await file.read()
        image = Image.open(io.BytesIO(payload)).convert("RGB")
        enhanced = improve_image(np.array(image))
        text = " ".join(reader.readtext(enhanced, detail=0))
        used, rem, method = calculate_usage_from_text(text)
        return {"used": used, "remaining": rem, "method": method}
    except Exception as e:
        return {"error": str(e)}
@app.post("/process-ocr")
async def process_ocr(file: UploadFile = File(...)):
    """Run raw OCR on an uploaded image and return the recognised text lines."""
    try:
        payload = await file.read()
        image = Image.open(io.BytesIO(payload)).convert("RGB")
        enhanced = improve_image(np.array(image))
        lines = reader.readtext(enhanced, detail=0)
        return {"type": "raw_ocr", "lines": lines}
    except Exception as e:
        return {"error": str(e)}
@app.post("/refresh-data")
async def trigger_refresh():
    """Manually re-pull the master-product cache from Supabase."""
    await refresh_data_internal()
    return {"status": "success"}