Commit
·
8cceab7
1
Parent(s):
790aee0
added id returning
Browse files- app/api/routes.py +1 -1
- app/core/config.py +78 -10
- app/dependencies.py +2 -0
- app/schemas/categories.py +15 -6
- app/services/autocategorizer.py +406 -218
app/api/routes.py
CHANGED
|
@@ -20,7 +20,7 @@ async def categorize_transaction(
|
|
| 20 |
) -> CategorizeResponse | JSONResponse:
|
| 21 |
started_at = time.monotonic()
|
| 22 |
try:
|
| 23 |
-
result = await service.categorize(payload.notes)
|
| 24 |
await api_logger.log_categorization(
|
| 25 |
name="Auto Expense Categorization",
|
| 26 |
status="success",
|
|
|
|
| 20 |
) -> CategorizeResponse | JSONResponse:
|
| 21 |
started_at = time.monotonic()
|
| 22 |
try:
|
| 23 |
+
result = await service.categorize(payload.notes, payload.user_id)
|
| 24 |
await api_logger.log_categorization(
|
| 25 |
name="Auto Expense Categorization",
|
| 26 |
status="success",
|
app/core/config.py
CHANGED
|
@@ -1,3 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from functools import lru_cache
|
| 2 |
|
| 3 |
from pydantic import Field
|
|
@@ -5,24 +11,86 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
| 5 |
|
| 6 |
|
| 7 |
class Settings(BaseSettings):
|
| 8 |
-
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
|
|
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
@lru_cache
|
| 24 |
def get_settings() -> Settings:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
return Settings()
|
| 26 |
|
| 27 |
|
|
|
|
| 28 |
settings = get_settings()
|
|
|
|
| 1 |
+
"""Application configuration settings.
|
| 2 |
+
|
| 3 |
+
This module handles all configuration settings loaded from environment variables.
|
| 4 |
+
Settings are validated using Pydantic and cached for performance.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
from functools import lru_cache
|
| 8 |
|
| 9 |
from pydantic import Field
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
class Settings(BaseSettings):
    """Environment-backed configuration for the categorization service.

    Fields are populated by pydantic-settings from the process environment and
    from a UTF-8 ``.env`` file; each field's ``alias`` is the variable name it
    is read from. Unknown variables are ignored and names are matched
    case-insensitively (see ``model_config``).
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
        case_sensitive=False,
    )

    # --- MongoDB ---------------------------------------------------------
    # Required: there is no usable default for the connection string.
    mongo_uri: str = Field(
        default=...,
        alias="MONGO_URI",
        description="MongoDB connection URI",
    )
    mongo_db: str = Field(
        default="expense",
        alias="MONGO_DB",
        description="MongoDB database name",
    )
    mongo_collection: str = Field(
        default="headcategories",
        alias="MONGO_COLLECTION",
        description="MongoDB collection name for headcategories",
    )
    mongo_subcategory_collection: str = Field(
        default="categories",
        alias="MONGO_SUBCATEGORY_COLLECTION",
        description="MongoDB collection name for categories",
    )
    api_logs_collection: str = Field(
        default="api_logs",
        alias="MONGO_API_LOGS_COLLECTION",
        description="MongoDB collection name for API logs",
    )

    # --- OpenAI ----------------------------------------------------------
    # Required: the service cannot call the model without a key.
    openai_api_key: str = Field(
        default=...,
        alias="OPENAI_API_KEY",
        description="OpenAI API key for LLM requests",
    )
    openai_model: str = Field(
        default="gpt-4o-mini",
        alias="OPENAI_MODEL",
        description="OpenAI model to use for categorization",
    )

    # --- Performance & caching -------------------------------------------
    category_cache_ttl_seconds: int = Field(
        default=300,
        alias="CATEGORY_CACHE_TTL",
        description="Time-to-live for category cache in seconds (5 minutes default)",
        ge=60,  # never cache for less than one minute
    )
    db_query_timeout_seconds: float = Field(
        default=5.0,
        alias="DB_QUERY_TIMEOUT",
        description="Timeout for database queries in seconds",
        ge=1.0,
        le=30.0,
    )
    model_api_timeout_seconds: float = Field(
        default=15.0,
        alias="MODEL_API_TIMEOUT",
        description="Timeout for OpenAI API calls in seconds",
        ge=5.0,
        le=60.0,
    )
|
| 83 |
|
| 84 |
|
| 85 |
@lru_cache
def get_settings() -> Settings:
    """Build the application settings once and reuse them afterwards.

    The function takes no arguments, so ``lru_cache`` makes every call after
    the first return the same ``Settings`` object.

    Returns:
        Settings: The cached settings instance.
    """
    loaded = Settings()
    return loaded
|
| 93 |
|
| 94 |
|
| 95 |
+
# Global settings instance
|
| 96 |
settings = get_settings()
|
app/dependencies.py
CHANGED
|
@@ -15,6 +15,8 @@ def _get_service() -> AutoCategoryService:
|
|
| 15 |
openai_client=openai_client,
|
| 16 |
model=settings.openai_model,
|
| 17 |
cache_ttl_seconds=settings.category_cache_ttl_seconds,
|
|
|
|
|
|
|
| 18 |
)
|
| 19 |
|
| 20 |
|
|
|
|
| 15 |
openai_client=openai_client,
|
| 16 |
model=settings.openai_model,
|
| 17 |
cache_ttl_seconds=settings.category_cache_ttl_seconds,
|
| 18 |
+
db_timeout_seconds=settings.db_query_timeout_seconds,
|
| 19 |
+
model_timeout_seconds=settings.model_api_timeout_seconds,
|
| 20 |
)
|
| 21 |
|
| 22 |
|
app/schemas/categories.py
CHANGED
|
@@ -2,19 +2,28 @@ from __future__ import annotations
|
|
| 2 |
|
| 3 |
from typing import Optional
|
| 4 |
|
| 5 |
-
from
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
class CategorizeRequest(BaseModel):
|
| 9 |
notes: str = Field(..., min_length=1, description="Full transaction note.")
|
| 10 |
-
user_id:
|
| 11 |
-
None, description="Optional user identifier associated with the request."
|
| 12 |
-
)
|
| 13 |
|
| 14 |
|
| 15 |
class CategoryPrediction(BaseModel):
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
class CategorizeResponse(BaseModel):
|
|
|
|
| 2 |
|
| 3 |
from typing import Optional
|
| 4 |
|
| 5 |
+
from bson import ObjectId
|
| 6 |
+
from pydantic import BaseModel, Field, field_validator
|
| 7 |
|
| 8 |
|
| 9 |
class CategorizeRequest(BaseModel):
    """Request body for categorizing a single transaction note."""

    # Both fields are required; ``notes`` must contain at least one character.
    notes: str = Field(default=..., min_length=1, description="Full transaction note.")
    user_id: str = Field(default=..., description="User identifier associated with the request.")
|
|
|
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
class CategoryPrediction(BaseModel):
    """Matched headcategory and category, each given as an id plus its title."""

    headcategory_id: str = Field(default=..., description="High-level category ObjectId.")
    headcategory_title: str = Field(default=..., description="High-level category title.")
    category_id: str = Field(default=..., description="Specific subcategory ObjectId chosen by the model.")
    category_title: str = Field(default=..., description="Specific subcategory title chosen by the model.")

    @field_validator("headcategory_id", "category_id")
    @classmethod
    def validate_object_id(cls, v: str) -> str:
        """Return ``v`` unchanged when it is a valid ObjectId string.

        Raises:
            ValueError: If ``ObjectId.is_valid`` rejects the value.
        """
        if ObjectId.is_valid(v):
            return v
        raise ValueError(f"Invalid ObjectId: {v}")
|
| 27 |
|
| 28 |
|
| 29 |
class CategorizeResponse(BaseModel):
|
app/services/autocategorizer.py
CHANGED
|
@@ -4,8 +4,9 @@ import ast
|
|
| 4 |
import asyncio
|
| 5 |
import json
|
| 6 |
import re
|
|
|
|
| 7 |
import time
|
| 8 |
-
from typing import Callable, Dict, List, Optional
|
| 9 |
|
| 10 |
from bson import ObjectId
|
| 11 |
from fastapi import HTTPException
|
|
@@ -19,127 +20,6 @@ from app.schemas.categories import CategoryPrediction
|
|
| 19 |
class AutoCategoryService:
|
| 20 |
"""Classifies transaction notes into the closest Mongo-backed category."""
|
| 21 |
|
| 22 |
-
# Curated categories requested by the client. When enabled via settings.use_static_categories,
|
| 23 |
-
# we bypass Mongo reads to avoid noisy data and long scans.
|
| 24 |
-
_STATIC_CATEGORIES: List[Dict[str, object]] = [
|
| 25 |
-
{
|
| 26 |
-
"title": "Food & Drinks",
|
| 27 |
-
"subcategories": ["Groceries", "Restaurant, Fast - Food", "Bar, Cafe", "Food & Drink"],
|
| 28 |
-
},
|
| 29 |
-
{
|
| 30 |
-
"title": "Investments",
|
| 31 |
-
"subcategories": [
|
| 32 |
-
"Investments",
|
| 33 |
-
"Realty",
|
| 34 |
-
"Vehicles, Chattels",
|
| 35 |
-
"Finacial investments",
|
| 36 |
-
"Savings",
|
| 37 |
-
"Collections",
|
| 38 |
-
],
|
| 39 |
-
},
|
| 40 |
-
{
|
| 41 |
-
"title": "Communication,PC",
|
| 42 |
-
"subcategories": ["Communication,PC", "Phone", "Internet", "Software, app, games", "Postal services"],
|
| 43 |
-
},
|
| 44 |
-
{
|
| 45 |
-
"title": "Financial Expenses",
|
| 46 |
-
"subcategories": [
|
| 47 |
-
"Financial expenses",
|
| 48 |
-
"Taxes",
|
| 49 |
-
"Insurances",
|
| 50 |
-
"Loan, interests",
|
| 51 |
-
"Fines",
|
| 52 |
-
"Advisory",
|
| 53 |
-
"Charges, Fees",
|
| 54 |
-
"Child Support",
|
| 55 |
-
],
|
| 56 |
-
},
|
| 57 |
-
{
|
| 58 |
-
"title": "Life & Entertainment",
|
| 59 |
-
"subcategories": [
|
| 60 |
-
"Life & Entertainment",
|
| 61 |
-
"Health, Care, Doctor",
|
| 62 |
-
"Wellness, Beauty",
|
| 63 |
-
"Active sport, Fitness",
|
| 64 |
-
"Culture, sport events",
|
| 65 |
-
"Life events",
|
| 66 |
-
"Hobbies",
|
| 67 |
-
"Education, Development",
|
| 68 |
-
"Books, Audio, subscription",
|
| 69 |
-
"TV, Streaming",
|
| 70 |
-
"Holiday, Trip, Hotels",
|
| 71 |
-
"Charity, Gifts",
|
| 72 |
-
"Alcohol, Tobacco",
|
| 73 |
-
"Lottery, Gamblings",
|
| 74 |
-
],
|
| 75 |
-
},
|
| 76 |
-
{
|
| 77 |
-
"title": "Vehicle",
|
| 78 |
-
"subcategories": [
|
| 79 |
-
"Vehicle",
|
| 80 |
-
"Fuel",
|
| 81 |
-
"Parking",
|
| 82 |
-
"Vehicle maintenance",
|
| 83 |
-
"Rentals",
|
| 84 |
-
"Vehicle insurance",
|
| 85 |
-
"Leasing",
|
| 86 |
-
],
|
| 87 |
-
},
|
| 88 |
-
{
|
| 89 |
-
"title": "Transportation",
|
| 90 |
-
"subcategories": ["Transportation", "Public transport", "Taxi", "Long distance", "Business trips"],
|
| 91 |
-
},
|
| 92 |
-
{
|
| 93 |
-
"title": "Housing",
|
| 94 |
-
"subcategories": [
|
| 95 |
-
"Housing",
|
| 96 |
-
"Rent",
|
| 97 |
-
"Mortgage",
|
| 98 |
-
"Energy, utilities",
|
| 99 |
-
"Services",
|
| 100 |
-
"Maintenance, repairs",
|
| 101 |
-
"Property insurance",
|
| 102 |
-
],
|
| 103 |
-
},
|
| 104 |
-
{
|
| 105 |
-
"title": "Shopping",
|
| 106 |
-
"subcategories": [
|
| 107 |
-
"Shopping",
|
| 108 |
-
"Clothes & shoes",
|
| 109 |
-
"Jewels & Accessories",
|
| 110 |
-
"Health & Beauty",
|
| 111 |
-
"Kids",
|
| 112 |
-
"Home & Garden",
|
| 113 |
-
"Pets & Animals",
|
| 114 |
-
"Electronics",
|
| 115 |
-
"Gift",
|
| 116 |
-
"Stationary",
|
| 117 |
-
"Free time",
|
| 118 |
-
"Chemist",
|
| 119 |
-
],
|
| 120 |
-
},
|
| 121 |
-
{
|
| 122 |
-
"title": "Income",
|
| 123 |
-
"subcategories": [
|
| 124 |
-
"Income",
|
| 125 |
-
"Wage, Invoices",
|
| 126 |
-
"Sale",
|
| 127 |
-
"Rental income",
|
| 128 |
-
"Dues & grants",
|
| 129 |
-
"Lending, renting",
|
| 130 |
-
"Checks, coupons",
|
| 131 |
-
"Lottery, gambling",
|
| 132 |
-
"Refunds",
|
| 133 |
-
"Child support",
|
| 134 |
-
"Gifts",
|
| 135 |
-
"Account Manage",
|
| 136 |
-
],
|
| 137 |
-
},
|
| 138 |
-
]
|
| 139 |
-
|
| 140 |
-
_categories_timeout_seconds = 15.0
|
| 141 |
-
_model_timeout_seconds = 20.0
|
| 142 |
-
|
| 143 |
def __init__(
|
| 144 |
self,
|
| 145 |
collection_getter: Callable[[], AsyncIOMotorCollection],
|
|
@@ -147,15 +27,20 @@ class AutoCategoryService:
|
|
| 147 |
openai_client: AsyncOpenAI,
|
| 148 |
model: str,
|
| 149 |
cache_ttl_seconds: int,
|
|
|
|
|
|
|
| 150 |
) -> None:
|
| 151 |
self._collection_getter = collection_getter
|
| 152 |
self._subcategory_collection_getter = subcategory_collection_getter
|
| 153 |
self._openai_client = openai_client
|
| 154 |
self._model = model
|
| 155 |
self._cache_ttl_seconds = cache_ttl_seconds
|
| 156 |
-
self.
|
| 157 |
-
self.
|
| 158 |
-
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
def _collection(self) -> AsyncIOMotorCollection:
|
| 161 |
return self._collection_getter()
|
|
@@ -163,72 +48,220 @@ class AutoCategoryService:
|
|
| 163 |
def _subcategory_collection(self) -> AsyncIOMotorCollection:
|
| 164 |
return self._subcategory_collection_getter()
|
| 165 |
|
| 166 |
-
async def categorize(self, notes: str) -> CategoryPrediction:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
try:
|
| 168 |
-
|
| 169 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
)
|
| 171 |
except asyncio.TimeoutError as exc:
|
| 172 |
raise HTTPException(status_code=504, detail="Timed out loading categories from database.") from exc
|
| 173 |
except Exception as exc:
|
| 174 |
raise HTTPException(status_code=502, detail="Failed to load categories from database.") from exc
|
| 175 |
|
| 176 |
-
if not categories:
|
| 177 |
-
raise HTTPException(status_code=500, detail="No categories
|
| 178 |
|
| 179 |
-
|
| 180 |
-
|
|
|
|
|
|
|
|
|
|
| 181 |
"Transaction note:\n"
|
| 182 |
f"{notes}\n\n"
|
| 183 |
-
"Available categories
|
| 184 |
f"{formatted_categories}\n\n"
|
| 185 |
-
"Respond with the exact title
|
| 186 |
)
|
| 187 |
|
| 188 |
-
|
| 189 |
model=self._model,
|
| 190 |
-
|
| 191 |
{
|
| 192 |
"role": "system",
|
| 193 |
"content": (
|
| 194 |
"You classify financial transactions into the closest category. "
|
| 195 |
-
"Only use the provided title
|
| 196 |
-
"Output valid JSON with
|
| 197 |
),
|
| 198 |
},
|
| 199 |
-
{"role": "user", "content":
|
| 200 |
],
|
| 201 |
)
|
| 202 |
|
| 203 |
try:
|
| 204 |
-
|
| 205 |
-
self._create_model_response(
|
| 206 |
-
timeout=self._model_timeout_seconds,
|
| 207 |
-
)
|
| 208 |
-
except TypeError as exc:
|
| 209 |
-
# Older openai-python clients (pre 1.3x) do not yet support response_format.
|
| 210 |
-
if "response_format" not in str(exc):
|
| 211 |
-
raise
|
| 212 |
-
response = await asyncio.wait_for(
|
| 213 |
-
self._openai_client.responses.create(**request_payload),
|
| 214 |
timeout=self._model_timeout_seconds,
|
| 215 |
)
|
| 216 |
except asyncio.TimeoutError as exc:
|
| 217 |
-
raise HTTPException(status_code=504, detail="Timed out waiting for model response.") from exc
|
| 218 |
except Exception as exc:
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
try:
|
| 222 |
-
|
| 223 |
except ValueError as exc:
|
| 224 |
-
raise HTTPException(status_code=502, detail="Failed to parse model output.") from exc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
raise HTTPException(status_code=502, detail="Model response missing category fields.")
|
| 230 |
|
| 231 |
-
return CategoryPrediction(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
|
| 233 |
def _parse_response_payload(self, response) -> Dict[str, object]:
|
| 234 |
raw_text = self._extract_response_text(response)
|
|
@@ -252,6 +285,14 @@ class AutoCategoryService:
|
|
| 252 |
|
| 253 |
@staticmethod
|
| 254 |
def _extract_response_text(response) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
text = getattr(response, "output_text", "") or ""
|
| 256 |
if isinstance(text, str) and text.strip():
|
| 257 |
return text.strip()
|
|
@@ -328,74 +369,221 @@ class AutoCategoryService:
|
|
| 328 |
|
| 329 |
return None
|
| 330 |
|
| 331 |
-
async def
|
| 332 |
-
|
|
|
|
| 333 |
now = time.monotonic()
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
|
| 384 |
-
|
| 385 |
-
self._last_loaded = now
|
| 386 |
-
return categories
|
| 387 |
|
| 388 |
async def _create_model_response(self, request_payload: Dict[str, object]):
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
|
| 394 |
@staticmethod
|
| 395 |
-
def
|
|
|
|
| 396 |
lines = []
|
| 397 |
for category in categories:
|
| 398 |
subs = category.get("subcategories") or []
|
| 399 |
-
subs_text = ", ".join(subs) if subs else "Unspecified"
|
| 400 |
lines.append(f"- {category.get('title', 'Unknown')}: {subs_text}")
|
| 401 |
return "\n".join(lines)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import asyncio
|
| 5 |
import json
|
| 6 |
import re
|
| 7 |
+
import string
|
| 8 |
import time
|
| 9 |
+
from typing import Callable, Dict, List, Optional, Tuple
|
| 10 |
|
| 11 |
from bson import ObjectId
|
| 12 |
from fastapi import HTTPException
|
|
|
|
| 20 |
class AutoCategoryService:
|
| 21 |
"""Classifies transaction notes into the closest Mongo-backed category."""
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
def __init__(
|
| 24 |
self,
|
| 25 |
collection_getter: Callable[[], AsyncIOMotorCollection],
|
|
|
|
| 27 |
openai_client: AsyncOpenAI,
|
| 28 |
model: str,
|
| 29 |
cache_ttl_seconds: int,
|
| 30 |
+
db_timeout_seconds: float,
|
| 31 |
+
model_timeout_seconds: float,
|
| 32 |
) -> None:
|
| 33 |
self._collection_getter = collection_getter
|
| 34 |
self._subcategory_collection_getter = subcategory_collection_getter
|
| 35 |
self._openai_client = openai_client
|
| 36 |
self._model = model
|
| 37 |
self._cache_ttl_seconds = cache_ttl_seconds
|
| 38 |
+
self._db_timeout_seconds = db_timeout_seconds
|
| 39 |
+
self._model_timeout_seconds = model_timeout_seconds
|
| 40 |
+
|
| 41 |
+
# User-specific cache for headcategories: {user_id: (data, timestamp)}
|
| 42 |
+
self._headcategories_cache: Dict[str, Tuple[Dict[str, object], float]] = {}
|
| 43 |
+
self._cache_lock = asyncio.Lock()
|
| 44 |
|
| 45 |
def _collection(self) -> AsyncIOMotorCollection:
|
| 46 |
return self._collection_getter()
|
|
|
|
| 48 |
def _subcategory_collection(self) -> AsyncIOMotorCollection:
|
| 49 |
return self._subcategory_collection_getter()
|
| 50 |
|
| 51 |
+
async def categorize(self, notes: str, user_id: str) -> CategoryPrediction:
    """Categorize a transaction note in two model-assisted steps.

    1. Ask the model which of the user's headcategories fits ``notes``.
    2. Ask the model which category inside that headcategory fits ``notes``.

    Args:
        notes: Free-text transaction note shown to the model.
        user_id: Identifier used to load (and cache) the user's headcategories.

    Returns:
        CategoryPrediction carrying the stringified ids and stored titles of
        the matched headcategory and category.

    Raises:
        HTTPException:
            * 504 when a database load or a model call times out.
            * 502 when a database load or model call fails, the model output
              cannot be parsed, or the returned title matches nothing.
            * 500 when the user has no headcategories, the selected
              headcategory has no categories, or a stored ``_id`` is not an
              ObjectId.
            * Any HTTPException raised by the loaders or the model wrapper
              (for example the 400 for a malformed ``user_id``) is passed
              through unchanged instead of being rewritten as 502.
    """
    # Step 1: load the user's headcategories (cached per user).
    try:
        headcategories_data = await asyncio.wait_for(
            self._get_headcategories_cached(user_id), timeout=self._db_timeout_seconds
        )
    except HTTPException:
        # Deliberate HTTP errors from the loader keep their own status code.
        raise
    except asyncio.TimeoutError as exc:
        raise HTTPException(status_code=504, detail="Timed out loading headcategories from database.") from exc
    except Exception as exc:
        raise HTTPException(status_code=502, detail="Failed to load headcategories from database.") from exc

    headcategories = headcategories_data.get("headcategories") if headcategories_data else None
    if not headcategories:
        raise HTTPException(status_code=500, detail="No headcategories configured for this user.")

    # Step 2: let the model pick a headcategory title, then resolve it.
    # ``or ""`` guards against documents whose title is stored as null.
    headcategory_titles = [hc.get("title") or "" for hc in headcategories]
    matched_headcategory_title = await self._ask_model_for_title(
        notes, headcategory_titles, "headcategory", "headcategories"
    )

    # Step 3: map the model's answer back to a stored headcategory.
    matched_headcategory = self._match_by_title(headcategories, matched_headcategory_title)
    if matched_headcategory is None:
        available_titles = ", ".join(headcategory_titles[:10])
        raise HTTPException(
            status_code=502,
            detail=(
                f"Could not find matching headcategory for title: '{matched_headcategory_title}'. "
                f"Available headcategories: {available_titles}"
            ),
        )

    headcategory_id = matched_headcategory.get("_id")
    category_ids = matched_headcategory.get("category_ids", [])

    if not isinstance(headcategory_id, ObjectId):
        raise HTTPException(status_code=500, detail="Invalid headcategory ID format.")

    if not category_ids:
        raise HTTPException(status_code=500, detail="Selected headcategory has no categories.")

    # Step 4: load the categories that belong to the chosen headcategory.
    try:
        categories_data = await asyncio.wait_for(
            self._get_categories_by_ids(category_ids), timeout=self._db_timeout_seconds
        )
    except HTTPException:
        raise
    except asyncio.TimeoutError as exc:
        raise HTTPException(status_code=504, detail="Timed out loading categories from database.") from exc
    except Exception as exc:
        raise HTTPException(status_code=502, detail="Failed to load categories from database.") from exc

    categories = categories_data.get("categories") if categories_data else None
    if not categories:
        raise HTTPException(status_code=500, detail="No categories found for the selected headcategory.")

    # Step 5: let the model pick a category title within that headcategory.
    category_titles = [cat.get("title") or "" for cat in categories]
    matched_category_title = await self._ask_model_for_title(
        notes, category_titles, "category", "categories"
    )

    # Step 6: map the model's answer back to a stored category.
    matched_category = self._match_by_title(categories, matched_category_title)
    if matched_category is None:
        available_titles = ", ".join(category_titles[:10])
        raise HTTPException(
            status_code=502,
            detail=(
                f"Could not find matching category for title: '{matched_category_title}'. "
                f"Available categories: {available_titles}"
            ),
        )

    category_id = matched_category.get("_id")
    if not isinstance(category_id, ObjectId):
        raise HTTPException(status_code=500, detail="Invalid category ID format.")

    # Report the stored titles, not the model's spelling of them.
    return CategoryPrediction(
        headcategory_id=str(headcategory_id),
        headcategory_title=matched_headcategory.get("title") or "",
        category_id=str(category_id),
        category_title=matched_category.get("title") or "",
    )


async def _ask_model_for_title(self, notes: str, titles: List[str], label: str, plural: str) -> str:
    """Ask the model to choose one of ``titles`` for ``notes`` and return its answer.

    ``label`` / ``plural`` ("headcategory"/"headcategories" or
    "category"/"categories") only select the wording of the prompt and of the
    error messages.

    Raises:
        HTTPException: 504 on model timeout; 502 when the call fails, the
            output cannot be parsed, or the ``title`` key is not a string.
    """
    # Empty titles are never offered to the model.
    formatted_titles = "\n".join(f"- {title}" for title in titles if title)
    prompt = (
        "Transaction note:\n"
        f"{notes}\n\n"
        f"Available {plural}:\n"
        f"{formatted_titles}\n\n"
        f"Respond with the exact {label} title from the list above that best matches this transaction."
    )
    request = dict(
        model=self._model,
        messages=[
            {
                "role": "system",
                "content": (
                    f"You classify financial transactions into the closest {label}. "
                    f"Only use the provided {label} title options. "
                    "Output valid JSON with key 'title'."
                ),
            },
            {"role": "user", "content": prompt},
        ],
    )

    try:
        response = await asyncio.wait_for(
            self._create_model_response(request),
            timeout=self._model_timeout_seconds,
        )
    except HTTPException:
        raise
    except asyncio.TimeoutError as exc:
        raise HTTPException(status_code=504, detail=f"Timed out waiting for {label} model response.") from exc
    except Exception as exc:
        raise HTTPException(
            status_code=502,
            detail=f"Failed to call the model API for {label}: {exc}",
        ) from exc

    try:
        payload = self._parse_response_payload(response)
    except ValueError as exc:
        raise HTTPException(status_code=502, detail=f"Failed to parse {label} model output.") from exc

    title = payload.get("title")
    if not isinstance(title, str):
        raise HTTPException(status_code=502, detail=f"Model response missing {label} title field.")
    return title


def _match_by_title(self, candidates: List[Dict[str, object]], wanted: str) -> Optional[Dict[str, object]]:
    """Return the first candidate whose title matches ``wanted``, else ``None``.

    A normalized-equality pass runs first; a case-insensitive substring pass
    (either direction) is the fallback. Candidates without a title are skipped
    in both passes, and the substring pass is skipped for a blank ``wanted``,
    because an empty string is a substring of every string and would otherwise
    "match" the first candidate.
    """
    titled = [(candidate, candidate.get("title") or "") for candidate in candidates]

    wanted_normalized = self._normalize_string(wanted)
    for candidate, title in titled:
        if title and self._normalize_string(title) == wanted_normalized:
            return candidate

    wanted_lower = wanted.strip().lower()
    if not wanted_lower:
        return None
    for candidate, title in titled:
        title_lower = title.lower()
        if title_lower and (wanted_lower in title_lower or title_lower in wanted_lower):
            return candidate
    return None
|
| 265 |
|
| 266 |
def _parse_response_payload(self, response) -> Dict[str, object]:
|
| 267 |
raw_text = self._extract_response_text(response)
|
|
|
|
| 285 |
|
| 286 |
@staticmethod
|
| 287 |
def _extract_response_text(response) -> str:
|
| 288 |
+
"""Extract text from OpenAI API response (supports both Chat Completions and Responses API)."""
|
| 289 |
+
# Try standard Chat Completions API format first
|
| 290 |
+
if hasattr(response, "choices") and response.choices:
|
| 291 |
+
message = response.choices[0].message
|
| 292 |
+
if hasattr(message, "content") and message.content:
|
| 293 |
+
return message.content.strip()
|
| 294 |
+
|
| 295 |
+
# Try Responses API format
|
| 296 |
text = getattr(response, "output_text", "") or ""
|
| 297 |
if isinstance(text, str) and text.strip():
|
| 298 |
return text.strip()
|
|
|
|
| 369 |
|
| 370 |
return None
|
| 371 |
|
| 372 |
+
async def _get_headcategories_cached(self, user_id: str) -> Dict[str, object]:
    """Return the user's headcategories, served from a per-user TTL cache.

    A cached entry younger than ``self._cache_ttl_seconds`` is returned as is.
    Otherwise the data is loaded via ``self._get_headcategories`` and stored
    with the current monotonic timestamp. Every store also drops all other
    expired entries, so users who never return do not pin memory forever.

    Args:
        user_id: Cache key; also passed to ``self._get_headcategories``.

    Returns:
        The mapping produced by ``self._get_headcategories``.
    """
    ttl = self._cache_ttl_seconds

    async with self._cache_lock:
        entry = self._headcategories_cache.get(user_id)
        if entry is not None:
            cached_data, cached_time = entry
            if (time.monotonic() - cached_time) < ttl:
                return cached_data
            # Expired: drop it so a failed reload does not leave stale data behind.
            del self._headcategories_cache[user_id]

    # The lock is not held while loading, so one slow lookup does not block
    # cache reads for other users.
    data = await self._get_headcategories(user_id)

    async with self._cache_lock:
        now = time.monotonic()
        expired = [
            uid
            for uid, (_, stored_at) in self._headcategories_cache.items()
            if (now - stored_at) >= ttl
        ]
        for uid in expired:
            del self._headcategories_cache[uid]
        self._headcategories_cache[user_id] = (data, now)

    return data
|
| 392 |
+
|
| 393 |
+
async def _get_headcategories(self, user_id: str) -> Dict[str, object]:
    """Load the headcategories owned by ``user_id`` from MongoDB.

    Args:
        user_id: String form of the owning user's ObjectId.

    Returns:
        ``{"headcategories": [...]}`` where every entry carries ``"_id"``,
        ``"title"`` and ``"category_ids"`` (only the ObjectId members of the
        document's ``categories`` array). Documents without an ObjectId
        ``_id`` or without any ObjectId category reference are left out.

    Raises:
        HTTPException: 400 when ``user_id`` cannot be converted to an ObjectId.
    """
    collection = self._collection()

    # The stored "user" field is compared against an ObjectId, so the
    # incoming string has to be converted (and validated) first.
    try:
        owner_id = ObjectId(user_id)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Invalid user_id format: {user_id}") from exc

    # Only documents whose "categories" is a non-empty array are useful.
    query = {"user": owner_id, "categories": {"$type": "array", "$ne": []}}
    # Projection keeps the payload small: just the fields read below.
    projection = {"_id": 1, "title": 1, "categories": 1}
    # NOTE: results are capped at 1000 documents.
    documents = await collection.find(query, projection).to_list(length=1000)

    result: List[Dict[str, object]] = []
    if not documents:
        return {"headcategories": result}

    for document in documents:
        document_id = document.get("_id")
        if not isinstance(document_id, ObjectId):
            continue

        references = document.get("categories") or []
        linked_ids = [ref for ref in references if isinstance(ref, ObjectId)]
        if linked_ids:
            result.append({
                "_id": document_id,
                "title": document.get("title", ""),
                "category_ids": linked_ids,
            })

    return {"headcategories": result}
|
| 430 |
+
|
| 431 |
+
async def _get_categories_by_ids(self, category_ids: List[ObjectId]) -> Dict[str, object]:
    """Look up category documents for the given ObjectIds.

    Args:
        category_ids: ``_id`` values to fetch; an empty list short-circuits
            without querying.

    Returns:
        ``{"categories": [...]}`` with one ``{"_id", "title"}`` dict per
        returned document whose ``_id`` is an ObjectId (a missing title
        defaults to ``""``).
    """
    collection = self._subcategory_collection()

    found: List[Dict[str, object]] = []
    if not category_ids:
        return {"categories": found}

    # Project only the two fields that end up in the result.
    matches = collection.find(
        {"_id": {"$in": category_ids}},
        {"title": 1, "_id": 1},
    )
    async for document in matches:
        document_id = document.get("_id")
        if not isinstance(document_id, ObjectId):
            continue
        found.append({"_id": document_id, "title": document.get("title", "")})

    return {"categories": found}
|
|
|
|
|
|
|
| 453 |
|
| 454 |
async def _create_model_response(self, request_payload: Dict[str, object]):
    """Request a JSON-object completion from the configured OpenAI client.

    The Chat Completions endpoint is tried first. If that call raises
    ``TypeError`` and the client exposes a ``responses`` attribute, the
    same arguments are sent to ``responses.create`` instead.

    Args:
        request_payload: Keyword arguments forwarded verbatim to the
            client's ``create`` call.

    Returns:
        The raw response object produced by whichever endpoint answered.

    Raises:
        TypeError: Re-raised when the chat call fails this way and the
            client has no ``responses`` attribute.
    """
    client = self._openai_client
    try:
        return await client.chat.completions.create(
            response_format={"type": "json_object"},
            **request_payload,
        )
    except TypeError:
        # Without a Responses endpoint there is nothing to fall back to.
        if "responses" not in dir(client):
            raise
        # Fallback for older openai-python clients or custom API endpoints.
        return await client.responses.create(
            response_format={"type": "json_object"},
            **request_payload,
        )
|
| 469 |
|
| 470 |
@staticmethod
def _format_categories_for_llm(categories: List[Dict[str, object]]) -> str:
    """Render categories as a bullet list for the LLM prompt.

    Each category becomes one line of the form
    ``- <title>: <sub1>, <sub2>, ...``. A missing category title is shown
    as ``Unknown``; a category with no usable subcategory titles is shown
    with ``Unspecified``.

    Args:
        categories: Dicts with a ``"title"`` and an optional
            ``"subcategories"`` list of dicts that each carry a ``"title"``.

    Returns:
        The lines joined by newlines, or ``""`` for an empty input.
    """
    lines = []
    for category in categories:
        sub_titles = []
        for sub in category.get("subcategories") or []:
            if not isinstance(sub, dict):
                continue
            sub_title = sub.get("title")
            # Skip None/blank titles: None would crash str.join, and blank
            # entries only add stray separators to the prompt.
            if isinstance(sub_title, str) and sub_title.strip():
                sub_titles.append(sub_title)

        # Fall back whenever nothing usable remains, not only when the
        # subcategory list itself is empty.
        subs_text = ", ".join(sub_titles) if sub_titles else "Unspecified"
        lines.append(f"- {category.get('title', 'Unknown')}: {subs_text}")
    return "\n".join(lines)
|
| 479 |
+
|
| 480 |
+
@staticmethod
def _normalize_string(s: str) -> str:
    """Return ``s`` lowercased, without ASCII punctuation, single-spaced.

    Used to compare titles loosely: every character in
    ``string.punctuation`` is deleted, the text is lowercased, and any run
    of whitespace (including leading/trailing) collapses to one space.
    """
    # Mapping a code point to None makes str.translate delete it.
    drop_punctuation = {ord(symbol): None for symbol in string.punctuation}
    words = s.translate(drop_punctuation).lower().split()
    # split() already discards surrounding whitespace, so no strip needed.
    return " ".join(words)
|
| 488 |
+
|
| 489 |
+
@staticmethod
def _find_matching_ids(
    categories: List[Dict[str, object]],
    title: str,
    subcategory: str
) -> tuple[ObjectId | None, ObjectId | None]:
    """Resolve a (title, subcategory) pair of names to their ObjectIds.

    Matching strategies are tried in order of strictness; each one scans
    every candidate before the next, looser one is attempted, so a strict
    match anywhere always wins over a loose match earlier in the list:

    1. Exact match (case-insensitive, surrounding whitespace ignored).
    2. Normalized match (punctuation removed; equal or one contains the
       other), e.g. "Wage" vs "Wage, Invoices".
    3. Partial match on the lowercased text (one contains the other, or
       the first word of one subcategory appears in the other).
    4. Word overlap (at least one whitespace-separated word in common).

    Fuzzy strategies (2-4) never match on an empty string: an empty string
    is contained in every string, so without this guard a blank or
    punctuation-only name would resolve to an arbitrary category.

    Args:
        categories: Dicts with ``"title"``, ``"headcategory_id"`` and a
            ``"subcategories"`` list of ``{"_id", "title"}`` dicts.
        title: Headcategory name to look for.
        subcategory: Category name to look for under that headcategory.

    Returns:
        ``(headcategory_id, category_id)`` of the first pair accepted by
        the strictest successful strategy, or ``(None, None)``. Pairs whose
        ids are not both ObjectIds are never returned.
    """
    normalize = AutoCategoryService._normalize_string

    def clean(value: object) -> str:
        """Lowercase/trim a title; non-string values (e.g. None) become ''."""
        return value.strip().lower() if isinstance(value, str) else ""

    def overlaps(left: str, right: str) -> bool:
        """True when both are non-empty and one contains the other."""
        return bool(left) and bool(right) and (left in right or right in left)

    def loosely_matches(candidate: str, wanted: str) -> bool:
        """Containment either way, or a first-word hit, on non-empty text."""
        if not candidate or not wanted:
            return False
        return (
            wanted in candidate
            or candidate in wanted
            or wanted.split()[0] in candidate
            or candidate.split()[0] in wanted
        )

    wanted_head = clean(title)
    wanted_sub = clean(subcategory)
    wanted_head_norm = normalize(wanted_head)
    wanted_sub_norm = normalize(wanted_sub)
    wanted_head_words = set(wanted_head.split())
    wanted_sub_words = set(wanted_sub.split())

    # Flatten once into (head_id, sub_id, head_title, sub_title) tuples,
    # keeping the original category/subcategory order and dropping pairs
    # that could never be returned because an id is not an ObjectId.
    candidates = []
    for category in categories:
        head_id = category.get("headcategory_id")
        if not isinstance(head_id, ObjectId):
            continue
        head_title = clean(category.get("title"))
        for sub in category.get("subcategories") or []:
            if not isinstance(sub, dict):
                continue
            sub_id = sub.get("_id")
            if isinstance(sub_id, ObjectId):
                candidates.append((head_id, sub_id, head_title, clean(sub.get("title"))))

    # (headcategory predicate, subcategory predicate), strictest first.
    strategies = (
        # 1. Exact match.
        (
            lambda head: head == wanted_head,
            lambda sub: sub == wanted_sub,
        ),
        # 2. Punctuation-insensitive equality or containment.
        (
            lambda head: overlaps(normalize(head), wanted_head_norm),
            lambda sub: overlaps(normalize(sub), wanted_sub_norm),
        ),
        # 3. Containment on the lowercased text, plus first-word hits.
        (
            lambda head: overlaps(head, wanted_head),
            lambda sub: loosely_matches(sub, wanted_sub),
        ),
        # 4. At least one shared whitespace-separated word.
        (
            lambda head: not wanted_head_words.isdisjoint(head.split()),
            lambda sub: not wanted_sub_words.isdisjoint(sub.split()),
        ),
    )

    for head_matches, sub_matches in strategies:
        for head_id, sub_id, head_title, sub_title in candidates:
            if head_matches(head_title) and sub_matches(sub_title):
                return head_id, sub_id

    return None, None
|