|
|
from __future__ import annotations |
|
|
|
|
|
import ast |
|
|
import asyncio |
|
|
import json |
|
|
import re |
|
|
import string |
|
|
import time |
|
|
from typing import Callable, Dict, List, Optional, Tuple |
|
|
|
|
|
from bson import ObjectId |
|
|
from fastapi import HTTPException |
|
|
from motor.motor_asyncio import AsyncIOMotorCollection |
|
|
from openai import AsyncOpenAI |
|
|
|
|
|
from app.core.config import settings |
|
|
from app.schemas.categories import CategoryPrediction |
|
|
|
|
|
|
|
|
class AutoCategoryService:
    """Classifies transaction notes into the closest Mongo-backed category.

    Classification is a two-step model flow: the model first picks one of the
    user's headcategory titles, then one of the category titles belonging to
    that headcategory. Each answer is mapped back to its database document by
    title matching. A user's headcategories are cached for
    ``cache_ttl_seconds``.
    """

    def __init__(
        self,
        collection_getter: Callable[[], AsyncIOMotorCollection],
        subcategory_collection_getter: Callable[[], AsyncIOMotorCollection],
        openai_client: AsyncOpenAI,
        model: str,
        cache_ttl_seconds: int,
        db_timeout_seconds: float,
        model_timeout_seconds: float,
    ) -> None:
        """Store dependencies and tuning knobs.

        Args:
            collection_getter: Returns the headcategory collection (queried by
                its ``user`` field in ``_get_headcategories``).
            subcategory_collection_getter: Returns the collection holding the
                categories referenced by a headcategory's ``categories`` list.
            openai_client: Async OpenAI client used for both model calls.
            model: Model name sent with every request.
            cache_ttl_seconds: How long a user's headcategories stay cached.
            db_timeout_seconds: Timeout applied to each database load.
            model_timeout_seconds: Timeout applied to each model call.
        """
        self._collection_getter = collection_getter
        self._subcategory_collection_getter = subcategory_collection_getter
        self._openai_client = openai_client
        self._model = model
        self._cache_ttl_seconds = cache_ttl_seconds
        self._db_timeout_seconds = db_timeout_seconds
        self._model_timeout_seconds = model_timeout_seconds

        # user_id -> (headcategories payload, time.monotonic() when it was stored)
        self._headcategories_cache: Dict[str, Tuple[Dict[str, object], float]] = {}
        self._cache_lock = asyncio.Lock()

    def _collection(self) -> AsyncIOMotorCollection:
        """Return the headcategory collection."""
        return self._collection_getter()

    def _subcategory_collection(self) -> AsyncIOMotorCollection:
        """Return the category (subcategory) collection."""
        return self._subcategory_collection_getter()

    async def categorize(self, notes: str, user_id: str) -> CategoryPrediction:
        """Categorize transaction notes using a two-step approach:

        1. First match notes to a headcategory title
        2. Then match notes to a category within that headcategory

        Raises:
            HTTPException: 400 for a malformed ``user_id``; 500 when the user's
                configuration is missing or malformed; 502 when the database or
                model call fails or the model answer matches no known title;
                504 on database/model timeouts.
        """
        headcategories_data = await self._await_db(
            self._get_headcategories_cached(user_id),
            timeout_detail="Timed out loading headcategories from database.",
            failure_detail="Failed to load headcategories from database.",
        )
        headcategories = (headcategories_data or {}).get("headcategories")
        if not headcategories:
            raise HTTPException(status_code=500, detail="No headcategories configured for this user.")

        # Step 1: the model picks a headcategory title, mapped back to its document.
        matched_headcategory = await self._classify(
            notes, headcategories, singular="headcategory", plural="headcategories"
        )

        headcategory_id = matched_headcategory.get("_id")
        category_ids = matched_headcategory.get("category_ids", [])
        if not isinstance(headcategory_id, ObjectId):
            raise HTTPException(status_code=500, detail="Invalid headcategory ID format.")
        if not category_ids:
            raise HTTPException(status_code=500, detail="Selected headcategory has no categories.")

        categories_data = await self._await_db(
            self._get_categories_by_ids(category_ids),
            timeout_detail="Timed out loading categories from database.",
            failure_detail="Failed to load categories from database.",
        )
        categories = (categories_data or {}).get("categories")
        if not categories:
            raise HTTPException(status_code=500, detail="No categories found for the selected headcategory.")

        # Step 2: same procedure, restricted to the chosen headcategory's categories.
        matched_category = await self._classify(notes, categories, singular="category", plural="categories")

        category_id = matched_category.get("_id")
        if not isinstance(category_id, ObjectId):
            raise HTTPException(status_code=500, detail="Invalid category ID format.")

        return CategoryPrediction(
            headcategory_id=str(headcategory_id),
            headcategory_title=matched_headcategory.get("title", ""),
            category_id=str(category_id),
            category_title=matched_category.get("title", ""),
        )

    async def _await_db(self, awaitable, *, timeout_detail: str, failure_detail: str):
        """Await a database load under the DB timeout, mapping failures to HTTP errors.

        Timeouts become 504 and other failures 502. An ``HTTPException`` raised
        by the load itself (e.g. the 400 for a malformed ``user_id``) is
        propagated unchanged instead of being masked as a 502.
        """
        try:
            return await asyncio.wait_for(awaitable, timeout=self._db_timeout_seconds)
        except asyncio.TimeoutError as exc:
            raise HTTPException(status_code=504, detail=timeout_detail) from exc
        except HTTPException:
            raise
        except Exception as exc:
            raise HTTPException(status_code=502, detail=failure_detail) from exc

    async def _classify(
        self,
        notes: str,
        items: List[Dict[str, object]],
        *,
        singular: str,
        plural: str,
    ) -> Dict[str, object]:
        """Ask the model to choose one of ``items`` by title and return that item.

        Raises:
            HTTPException: 502 when the model answer matches none of the items
                (plus any error raised by ``_ask_model_for_title``).
        """
        titles = [self._title_of(item) for item in items]
        answer = await self._ask_model_for_title(notes, titles, singular=singular, plural=plural)

        matched = self._match_by_title(items, answer)
        if matched is None:
            available_titles = ", ".join(titles[:10])
            raise HTTPException(
                status_code=502,
                detail=(
                    f"Could not find matching {singular} for title: '{answer}'. "
                    f"Available {plural}: {available_titles}"
                ),
            )
        return matched

    async def _ask_model_for_title(
        self,
        notes: str,
        titles: List[str],
        *,
        singular: str,
        plural: str,
    ) -> str:
        """Prompt the model with the note and candidate titles; return its chosen title.

        Raises:
            HTTPException: 504 on model timeout; 502 when the call fails, the
                output cannot be parsed, or the ``title`` key is not a string.
        """
        formatted_titles = "\n".join(f"- {title}" for title in titles if title)
        user_prompt = (
            "Transaction note:\n"
            f"{notes}\n\n"
            f"Available {plural}:\n"
            f"{formatted_titles}\n\n"
            f"Respond with the exact {singular} title from the list above that best matches this transaction."
        )
        request = dict(
            model=self._model,
            messages=[
                {
                    "role": "system",
                    "content": (
                        f"You classify financial transactions into the closest {singular}. "
                        f"Only use the provided {singular} title options. "
                        "Output valid JSON with key 'title'."
                    ),
                },
                {"role": "user", "content": user_prompt},
            ],
        )

        try:
            response = await asyncio.wait_for(
                self._create_model_response(request),
                timeout=self._model_timeout_seconds,
            )
        except asyncio.TimeoutError as exc:
            raise HTTPException(
                status_code=504, detail=f"Timed out waiting for {singular} model response."
            ) from exc
        except Exception as exc:
            raise HTTPException(
                status_code=502,
                detail=f"Failed to call the model API for {singular}: {exc}",
            ) from exc

        try:
            payload = self._parse_response_payload(response)
        except ValueError as exc:
            raise HTTPException(status_code=502, detail=f"Failed to parse {singular} model output.") from exc

        title = payload.get("title")
        if not isinstance(title, str):
            raise HTTPException(status_code=502, detail=f"Model response missing {singular} title field.")
        return title

    @staticmethod
    def _title_of(item: Dict[str, object]) -> str:
        """Return ``item['title']`` when it is a string, else ``""``."""
        title = item.get("title")
        return title if isinstance(title, str) else ""

    @classmethod
    def _match_by_title(cls, items: List[Dict[str, object]], title: str) -> Optional[Dict[str, object]]:
        """Return the item whose title best matches ``title``, or ``None``.

        First tries an exact match after ``_normalize_string`` (case,
        punctuation and whitespace insensitive), then case-insensitive substring
        containment in either direction. Empty strings never match: ``""`` is a
        substring of everything, so allowing them would let an empty model
        answer (or an untitled document) pick an arbitrary item.
        """
        wanted = cls._normalize_string(title)
        if wanted:
            for item in items:
                if cls._normalize_string(cls._title_of(item)) == wanted:
                    return item

        wanted_lower = title.strip().lower()
        if not wanted_lower:
            return None
        for item in items:
            candidate = cls._title_of(item).strip().lower()
            if candidate and (wanted_lower in candidate or candidate in wanted_lower):
                return item
        return None

    def _parse_response_payload(self, response) -> Dict[str, object]:
        """Extract the model's text and coerce it into a dict.

        Tries the fence-stripped text and, if different, the outermost ``{...}``
        snippet inside it, each with JSON, Python-literal and loose key/value
        parsing in that order.

        Raises:
            ValueError: when there is no text or nothing parses into a dict.
        """
        raw_text = self._extract_response_text(response)
        if not raw_text:
            raise ValueError("Model response did not contain text content.")

        cleaned = self._strip_code_fence(raw_text)
        candidates = [cleaned]

        json_snippet = self._extract_json_snippet(cleaned)
        if json_snippet and json_snippet not in candidates:
            candidates.append(json_snippet)

        for candidate in candidates:
            for parser in (self._try_parse_json, self._try_parse_literal_dict, self._try_parse_key_values):
                payload = parser(candidate)
                if payload:
                    return payload

        raise ValueError("Unable to coerce model response into a payload.")

    @staticmethod
    def _extract_response_text(response) -> str:
        """Extract text from OpenAI API response (supports both Chat Completions and Responses API).

        Checks, in order: ``choices[0].message.content`` (Chat Completions),
        ``output_text``, then the first non-blank ``output[*].content[*].text``.
        Returns ``""`` when none is present.
        """
        if hasattr(response, "choices") and response.choices:
            message = response.choices[0].message
            if hasattr(message, "content") and message.content:
                return message.content.strip()

        text = getattr(response, "output_text", "") or ""
        if isinstance(text, str) and text.strip():
            return text.strip()

        outputs = getattr(response, "output", []) or []
        for output in outputs:
            contents = getattr(output, "content", []) or []
            for content in contents:
                value = getattr(content, "text", None)
                if isinstance(value, str) and value.strip():
                    return value.strip()

        return ""

    @staticmethod
    def _strip_code_fence(raw_text: str) -> str:
        """Remove a surrounding Markdown code fence (```...```), if any.

        An optional language tag on the opening line (e.g. ``json``) is dropped.
        Content on the same line as the closing fence is kept; the previous
        line-based slicing discarded it (```` ```json\\n{...}``` ```` became ``""``).
        """
        text = raw_text.strip()
        if len(text) < 6 or not (text.startswith("```") and text.endswith("```")):
            return text

        inner = text[3:-3]
        first_line, newline, rest = inner.partition("\n")
        # Only drop the first line when it looks like a language tag, so a
        # payload that starts right after the opening fence is preserved.
        if newline and re.fullmatch(r"[\w+.-]*", first_line.strip()):
            inner = rest
        return inner.strip()

    @staticmethod
    def _extract_json_snippet(raw_text: str) -> Optional[str]:
        """Return the span from the first ``{`` to the last ``}``, or ``None``."""
        start = raw_text.find("{")
        end = raw_text.rfind("}")
        if start == -1 or end == -1 or end <= start:
            return None
        return raw_text[start : end + 1]

    @staticmethod
    def _try_parse_json(raw_text: str) -> Optional[Dict[str, object]]:
        """Parse ``raw_text`` as a JSON object; return ``None`` on failure or non-dict."""
        text = raw_text.strip()
        if not text:
            return None
        try:
            payload = json.loads(text)
        except (ValueError, RecursionError):  # JSONDecodeError subclasses ValueError
            return None
        return payload if isinstance(payload, dict) else None

    @staticmethod
    def _try_parse_literal_dict(raw_text: str) -> Optional[Dict[str, object]]:
        """Parse ``raw_text`` as a Python dict literal; return ``None`` on failure or non-dict.

        ``ast.literal_eval`` evaluates literals only (no code execution), but
        on malformed input it can raise any of the exceptions caught here,
        e.g. ``TypeError`` for ``{[1]: 2}``.
        """
        try:
            payload = ast.literal_eval(raw_text)
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
            return None
        return payload if isinstance(payload, dict) else None

    @staticmethod
    def _try_parse_key_values(raw_text: str) -> Optional[Dict[str, object]]:
        """Best-effort parse of loose ``key: value`` / ``key=value`` text.

        Pairs are separated by newlines, ``;`` or ``,``. ``title``/``category``
        set the title; ``subcategory``/``sub_category``/``sub`` set the
        subcategory. Returns ``None`` unless a title is found. The subcategory is
        optional, because the classification prompts only ask for ``title``.
        """
        title: Optional[str] = None
        subcategory: Optional[str] = None
        for chunk in re.split(r"[\n;,]+", raw_text):
            if ":" in chunk:
                key, value = chunk.split(":", 1)
            elif "=" in chunk:
                key, value = chunk.split("=", 1)
            else:
                continue
            # Strip quotes/braces too, so near-JSON like `{"title": "Food"` still parses.
            key_normalized = key.strip().strip("\"'{} \t").lower()
            value_clean = value.strip().strip("\"'{} \t")
            if not value_clean:
                continue
            if key_normalized in {"title", "category"}:
                title = value_clean
            elif key_normalized in {"subcategory", "sub_category", "sub"}:
                subcategory = value_clean

        if not title:
            return None
        payload: Dict[str, object] = {"title": title}
        if subcategory:
            payload["subcategory"] = subcategory
        return payload

    async def _get_headcategories_cached(self, user_id: str) -> Dict[str, object]:
        """Fetch headcategories from MongoDB with user-specific caching.

        The lock only guards cache access; the database query itself runs
        outside it, so one user's slow query does not block cache hits for
        others. Expired entries are pruned whenever a fresh result is stored,
        so the cache does not grow with users who stopped calling.
        """
        async with self._cache_lock:
            entry = self._headcategories_cache.get(user_id)
            if entry is not None:
                cached_data, cached_time = entry
                if (time.monotonic() - cached_time) < self._cache_ttl_seconds:
                    return cached_data
                del self._headcategories_cache[user_id]

        data = await self._get_headcategories(user_id)

        async with self._cache_lock:
            now = time.monotonic()
            expired = [
                key
                for key, (_, stored_at) in self._headcategories_cache.items()
                if (now - stored_at) >= self._cache_ttl_seconds
            ]
            for key in expired:
                del self._headcategories_cache[key]
            self._headcategories_cache[user_id] = (data, now)

        return data

    async def _get_headcategories(self, user_id: str) -> Dict[str, object]:
        """Fetch headcategories from MongoDB filtered by user_id.

        Reads at most 1000 documents whose ``categories`` field is a non-empty
        array. Non-ObjectId category references are dropped, and documents left
        with none are skipped.

        Returns:
            ``{"headcategories": [{"_id", "title", "category_ids"}, ...]}``.

        Raises:
            HTTPException: 400 when ``user_id`` is not a valid ObjectId.
        """
        head_collection = self._collection()

        try:
            user_object_id = ObjectId(user_id)
        except Exception as exc:
            raise HTTPException(status_code=400, detail=f"Invalid user_id format: {user_id}") from exc

        head_docs = await head_collection.find(
            {"user": user_object_id, "categories": {"$type": "array", "$ne": []}},
            {"_id": 1, "title": 1, "categories": 1},
        ).to_list(length=1000)

        headcategories: List[Dict[str, object]] = []
        for head_doc in head_docs or []:
            head_id = head_doc.get("_id")
            if not isinstance(head_id, ObjectId):
                continue

            category_ids = [cid for cid in (head_doc.get("categories") or []) if isinstance(cid, ObjectId)]
            if not category_ids:
                continue

            headcategories.append({
                "_id": head_id,
                "title": head_doc.get("title", ""),
                "category_ids": category_ids,
            })

        return {"headcategories": headcategories}

    async def _get_categories_by_ids(self, category_ids: List[ObjectId]) -> Dict[str, object]:
        """Fetch categories from MongoDB by their ObjectIds.

        Returns:
            ``{"categories": [{"_id", "title"}, ...]}`` in cursor order;
            documents without an ObjectId ``_id`` are skipped.
        """
        if not category_ids:
            return {"categories": []}

        subcategory_collection = self._subcategory_collection()
        categories: List[Dict[str, object]] = []
        cursor = subcategory_collection.find(
            {"_id": {"$in": category_ids}},
            {"title": 1, "_id": 1},
        )
        async for cat_doc in cursor:
            cat_id = cat_doc.get("_id")
            if isinstance(cat_id, ObjectId):
                categories.append({
                    "_id": cat_id,
                    "title": cat_doc.get("title", ""),
                })

        return {"categories": categories}

    async def _create_model_response(self, request_payload: Dict[str, object]):
        """Create a JSON-mode model response.

        Uses the Chat Completions API. If that call raises ``TypeError`` and the
        client exposes ``responses``, the request is retried via the Responses
        API after translating it to that API's shape: ``messages`` becomes
        ``input`` and JSON mode becomes ``text={"format": ...}``. Passing the
        Chat-Completions kwargs through unchanged would just be rejected again.
        """
        try:
            return await self._openai_client.chat.completions.create(
                response_format={"type": "json_object"},
                **request_payload,
            )
        except TypeError:
            responses_api = getattr(self._openai_client, "responses", None)
            if responses_api is None:
                raise
            fallback_payload = dict(request_payload)
            messages = fallback_payload.pop("messages", [])
            return await responses_api.create(
                input=messages,
                text={"format": {"type": "json_object"}},
                **fallback_payload,
            )

    @staticmethod
    def _format_categories_for_llm(categories: List[Dict[str, object]]) -> str:
        """Format categories for LLM prompt.

        One line per category: ``- <title>: <comma-separated subcategory titles>``,
        or ``Unspecified`` when it has no subcategories.
        """
        lines = []
        for category in categories:
            subs = category.get("subcategories") or []
            sub_titles = [str(sub.get("title") or "") for sub in subs if isinstance(sub, dict)]
            subs_text = ", ".join(sub_titles) if subs else "Unspecified"
            lines.append(f"- {category.get('title', 'Unknown')}: {subs_text}")
        return "\n".join(lines)

    @staticmethod
    def _normalize_string(s: str) -> str:
        """Normalize string by removing punctuation and extra spaces for better matching.

        Removes ``string.punctuation``, lowercases, and collapses runs of
        whitespace to single spaces.
        """
        normalized = s.translate(str.maketrans('', '', string.punctuation)).lower().strip()
        return ' '.join(normalized.split())

    @staticmethod
    def _find_matching_ids(
        categories: List[Dict[str, object]],
        title: str,
        subcategory: str
    ) -> tuple[ObjectId | None, ObjectId | None]:
        """Find matching headcategory_id and category_id based on title and subcategory strings.

        Uses flexible matching, trying each pass over all categories before the next:
        1. Exact match (case-insensitive)
        2. Normalized match (removes punctuation), allowing containment either way
        3. Partial match (one contains the other, or first-word containment for subcategories)
        4. Word-based match (any shared word)

        Empty strings never count as substring/word matches, and missing titles
        or subcategory lists are treated as empty rather than raising.
        Returns ``(None, None)`` when nothing matches with ObjectId ids.
        """
        normalize = AutoCategoryService._normalize_string

        def text_of(value: object) -> str:
            return value.strip().lower() if isinstance(value, str) else ""

        def loosely_equal(a: str, b: str) -> bool:
            # "" is a substring of everything, so empty strings must not match.
            return bool(a) and bool(b) and (a == b or a in b or b in a)

        def first_word(value: str) -> str:
            words = value.split()
            return words[0] if words else ""

        title_lower = text_of(title)
        subcategory_lower = text_of(subcategory)
        title_normalized = normalize(title_lower)
        subcategory_normalized = normalize(subcategory_lower)
        title_words = set(title_lower.split())
        subcategory_words = set(subcategory_lower.split())
        subcategory_first = first_word(subcategory_lower)

        def sub_partial(sub_title: str) -> bool:
            if loosely_equal(sub_title, subcategory_lower):
                return True
            sub_first = first_word(sub_title)
            return bool(subcategory_first and subcategory_first in sub_title) or bool(
                sub_first and sub_first in subcategory_lower
            )

        passes = (
            (lambda head: head == title_lower, lambda sub: sub == subcategory_lower),
            (
                lambda head: loosely_equal(normalize(head), title_normalized),
                lambda sub: loosely_equal(normalize(sub), subcategory_normalized),
            ),
            (lambda head: loosely_equal(head, title_lower), sub_partial),
            (
                lambda head: bool(title_words & set(head.split())),
                lambda sub: bool(subcategory_words & set(sub.split())),
            ),
        )

        for head_matches, sub_matches in passes:
            for category in categories:
                if not head_matches(text_of(category.get("title"))):
                    continue
                headcategory_id = category.get("headcategory_id")
                for sub in category.get("subcategories") or []:
                    if not isinstance(sub, dict) or not sub_matches(text_of(sub.get("title"))):
                        continue
                    category_id = sub.get("_id")
                    if isinstance(headcategory_id, ObjectId) and isinstance(category_id, ObjectId):
                        return headcategory_id, category_id

        return None, None
|
|
|