HariLogicgo commited on
Commit
8cceab7
·
1 Parent(s): 790aee0

added id returning

Browse files
app/api/routes.py CHANGED
@@ -20,7 +20,7 @@ async def categorize_transaction(
20
  ) -> CategorizeResponse | JSONResponse:
21
  started_at = time.monotonic()
22
  try:
23
- result = await service.categorize(payload.notes)
24
  await api_logger.log_categorization(
25
  name="Auto Expense Categorization",
26
  status="success",
 
20
  ) -> CategorizeResponse | JSONResponse:
21
  started_at = time.monotonic()
22
  try:
23
+ result = await service.categorize(payload.notes, payload.user_id)
24
  await api_logger.log_categorization(
25
  name="Auto Expense Categorization",
26
  status="success",
app/core/config.py CHANGED
@@ -1,3 +1,9 @@
 
 
 
 
 
 
1
  from functools import lru_cache
2
 
3
  from pydantic import Field
@@ -5,24 +11,86 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
5
 
6
 
7
  class Settings(BaseSettings):
8
- model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
9
 
10
- mongo_uri: str = Field(..., alias="MONGO_URI")
11
- mongo_db: str = Field("expense", alias="MONGO_DB")
12
- mongo_collection: str = Field("headcategories", alias="MONGO_COLLECTION")
13
- mongo_subcategory_collection: str = Field("categories", alias="MONGO_SUBCATEGORY_COLLECTION")
14
- api_logs_collection: str = Field("api_logs", alias="MONGO_API_LOGS_COLLECTION")
 
15
 
16
- openai_api_key: str = Field(..., alias="OPENAI_API_KEY")
17
- openai_model: str = Field("gpt-4o-mini", alias="OPENAI_MODEL")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- category_cache_ttl_seconds: int = Field(300, alias="CATEGORY_CACHE_TTL")
20
- use_static_categories: bool = Field(True, alias="USE_STATIC_CATEGORIES")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
 
23
  @lru_cache
24
  def get_settings() -> Settings:
 
 
 
 
 
25
  return Settings()
26
 
27
 
 
28
  settings = get_settings()
 
1
+ """Application configuration settings.
2
+
3
+ This module handles all configuration settings loaded from environment variables.
4
+ Settings are validated using Pydantic and cached for performance.
5
+ """
6
+
7
  from functools import lru_cache
8
 
9
  from pydantic import Field
 
11
 
12
 
13
  class Settings(BaseSettings):
14
+ """Application settings loaded from environment variables."""
15
 
16
+ model_config = SettingsConfigDict(
17
+ env_file=".env",
18
+ env_file_encoding="utf-8",
19
+ extra="ignore",
20
+ case_sensitive=False,
21
+ )
22
 
23
+ # MongoDB Configuration
24
+ mongo_uri: str = Field(
25
+ ...,
26
+ alias="MONGO_URI",
27
+ description="MongoDB connection URI",
28
+ )
29
+ mongo_db: str = Field(
30
+ "expense",
31
+ alias="MONGO_DB",
32
+ description="MongoDB database name",
33
+ )
34
+ mongo_collection: str = Field(
35
+ "headcategories",
36
+ alias="MONGO_COLLECTION",
37
+ description="MongoDB collection name for headcategories",
38
+ )
39
+ mongo_subcategory_collection: str = Field(
40
+ "categories",
41
+ alias="MONGO_SUBCATEGORY_COLLECTION",
42
+ description="MongoDB collection name for categories",
43
+ )
44
+ api_logs_collection: str = Field(
45
+ "api_logs",
46
+ alias="MONGO_API_LOGS_COLLECTION",
47
+ description="MongoDB collection name for API logs",
48
+ )
49
 
50
+ # OpenAI Configuration
51
+ openai_api_key: str = Field(
52
+ ...,
53
+ alias="OPENAI_API_KEY",
54
+ description="OpenAI API key for LLM requests",
55
+ )
56
+ openai_model: str = Field(
57
+ "gpt-4o-mini",
58
+ alias="OPENAI_MODEL",
59
+ description="OpenAI model to use for categorization",
60
+ )
61
+
62
+ # Performance & Caching Configuration
63
+ category_cache_ttl_seconds: int = Field(
64
+ 300,
65
+ alias="CATEGORY_CACHE_TTL",
66
+ description="Time-to-live for category cache in seconds (5 minutes default)",
67
+ ge=60, # Minimum 1 minute
68
+ )
69
+ db_query_timeout_seconds: float = Field(
70
+ 5.0,
71
+ alias="DB_QUERY_TIMEOUT",
72
+ description="Timeout for database queries in seconds",
73
+ ge=1.0,
74
+ le=30.0,
75
+ )
76
+ model_api_timeout_seconds: float = Field(
77
+ 15.0,
78
+ alias="MODEL_API_TIMEOUT",
79
+ description="Timeout for OpenAI API calls in seconds",
80
+ ge=5.0,
81
+ le=60.0,
82
+ )
83
 
84
 
85
  @lru_cache
86
  def get_settings() -> Settings:
87
+ """Get cached settings instance.
88
+
89
+ Returns:
90
+ Settings: Application settings instance
91
+ """
92
  return Settings()
93
 
94
 
95
+ # Global settings instance
96
  settings = get_settings()
app/dependencies.py CHANGED
@@ -15,6 +15,8 @@ def _get_service() -> AutoCategoryService:
15
  openai_client=openai_client,
16
  model=settings.openai_model,
17
  cache_ttl_seconds=settings.category_cache_ttl_seconds,
 
 
18
  )
19
 
20
 
 
15
  openai_client=openai_client,
16
  model=settings.openai_model,
17
  cache_ttl_seconds=settings.category_cache_ttl_seconds,
18
+ db_timeout_seconds=settings.db_query_timeout_seconds,
19
+ model_timeout_seconds=settings.model_api_timeout_seconds,
20
  )
21
 
22
 
app/schemas/categories.py CHANGED
@@ -2,19 +2,28 @@ from __future__ import annotations
2
 
3
  from typing import Optional
4
 
5
- from pydantic import BaseModel, Field
 
6
 
7
 
8
  class CategorizeRequest(BaseModel):
9
  notes: str = Field(..., min_length=1, description="Full transaction note.")
10
- user_id: Optional[str] = Field(
11
- None, description="Optional user identifier associated with the request."
12
- )
13
 
14
 
15
  class CategoryPrediction(BaseModel):
16
- title: str = Field(..., description="High-level category title.")
17
- subcategory: str = Field(..., description="Specific subcategory chosen by the model.")
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  class CategorizeResponse(BaseModel):
 
2
 
3
  from typing import Optional
4
 
5
+ from bson import ObjectId
6
+ from pydantic import BaseModel, Field, field_validator
7
 
8
 
9
  class CategorizeRequest(BaseModel):
10
  notes: str = Field(..., min_length=1, description="Full transaction note.")
11
+ user_id: str = Field(..., description="User identifier associated with the request.")
 
 
12
 
13
 
14
  class CategoryPrediction(BaseModel):
15
+ headcategory_id: str = Field(..., description="High-level category ObjectId.")
16
+ headcategory_title: str = Field(..., description="High-level category title.")
17
+ category_id: str = Field(..., description="Specific subcategory ObjectId chosen by the model.")
18
+ category_title: str = Field(..., description="Specific subcategory title chosen by the model.")
19
+
20
+ @field_validator('headcategory_id', 'category_id')
21
+ @classmethod
22
+ def validate_object_id(cls, v: str) -> str:
23
+ """Validate that the string is a valid ObjectId."""
24
+ if not ObjectId.is_valid(v):
25
+ raise ValueError(f"Invalid ObjectId: {v}")
26
+ return v
27
 
28
 
29
  class CategorizeResponse(BaseModel):
app/services/autocategorizer.py CHANGED
@@ -4,8 +4,9 @@ import ast
4
  import asyncio
5
  import json
6
  import re
 
7
  import time
8
- from typing import Callable, Dict, List, Optional
9
 
10
  from bson import ObjectId
11
  from fastapi import HTTPException
@@ -19,127 +20,6 @@ from app.schemas.categories import CategoryPrediction
19
  class AutoCategoryService:
20
  """Classifies transaction notes into the closest Mongo-backed category."""
21
 
22
- # Curated categories requested by the client. When enabled via settings.use_static_categories,
23
- # we bypass Mongo reads to avoid noisy data and long scans.
24
- _STATIC_CATEGORIES: List[Dict[str, object]] = [
25
- {
26
- "title": "Food & Drinks",
27
- "subcategories": ["Groceries", "Restaurant, Fast - Food", "Bar, Cafe", "Food & Drink"],
28
- },
29
- {
30
- "title": "Investments",
31
- "subcategories": [
32
- "Investments",
33
- "Realty",
34
- "Vehicles, Chattels",
35
- "Finacial investments",
36
- "Savings",
37
- "Collections",
38
- ],
39
- },
40
- {
41
- "title": "Communication,PC",
42
- "subcategories": ["Communication,PC", "Phone", "Internet", "Software, app, games", "Postal services"],
43
- },
44
- {
45
- "title": "Financial Expenses",
46
- "subcategories": [
47
- "Financial expenses",
48
- "Taxes",
49
- "Insurances",
50
- "Loan, interests",
51
- "Fines",
52
- "Advisory",
53
- "Charges, Fees",
54
- "Child Support",
55
- ],
56
- },
57
- {
58
- "title": "Life & Entertainment",
59
- "subcategories": [
60
- "Life & Entertainment",
61
- "Health, Care, Doctor",
62
- "Wellness, Beauty",
63
- "Active sport, Fitness",
64
- "Culture, sport events",
65
- "Life events",
66
- "Hobbies",
67
- "Education, Development",
68
- "Books, Audio, subscription",
69
- "TV, Streaming",
70
- "Holiday, Trip, Hotels",
71
- "Charity, Gifts",
72
- "Alcohol, Tobacco",
73
- "Lottery, Gamblings",
74
- ],
75
- },
76
- {
77
- "title": "Vehicle",
78
- "subcategories": [
79
- "Vehicle",
80
- "Fuel",
81
- "Parking",
82
- "Vehicle maintenance",
83
- "Rentals",
84
- "Vehicle insurance",
85
- "Leasing",
86
- ],
87
- },
88
- {
89
- "title": "Transportation",
90
- "subcategories": ["Transportation", "Public transport", "Taxi", "Long distance", "Business trips"],
91
- },
92
- {
93
- "title": "Housing",
94
- "subcategories": [
95
- "Housing",
96
- "Rent",
97
- "Mortgage",
98
- "Energy, utilities",
99
- "Services",
100
- "Maintenance, repairs",
101
- "Property insurance",
102
- ],
103
- },
104
- {
105
- "title": "Shopping",
106
- "subcategories": [
107
- "Shopping",
108
- "Clothes & shoes",
109
- "Jewels & Accessories",
110
- "Health & Beauty",
111
- "Kids",
112
- "Home & Garden",
113
- "Pets & Animals",
114
- "Electronics",
115
- "Gift",
116
- "Stationary",
117
- "Free time",
118
- "Chemist",
119
- ],
120
- },
121
- {
122
- "title": "Income",
123
- "subcategories": [
124
- "Income",
125
- "Wage, Invoices",
126
- "Sale",
127
- "Rental income",
128
- "Dues & grants",
129
- "Lending, renting",
130
- "Checks, coupons",
131
- "Lottery, gambling",
132
- "Refunds",
133
- "Child support",
134
- "Gifts",
135
- "Account Manage",
136
- ],
137
- },
138
- ]
139
-
140
- _categories_timeout_seconds = 15.0
141
- _model_timeout_seconds = 20.0
142
-
143
  def __init__(
144
  self,
145
  collection_getter: Callable[[], AsyncIOMotorCollection],
@@ -147,15 +27,20 @@ class AutoCategoryService:
147
  openai_client: AsyncOpenAI,
148
  model: str,
149
  cache_ttl_seconds: int,
 
 
150
  ) -> None:
151
  self._collection_getter = collection_getter
152
  self._subcategory_collection_getter = subcategory_collection_getter
153
  self._openai_client = openai_client
154
  self._model = model
155
  self._cache_ttl_seconds = cache_ttl_seconds
156
- self._cached_categories: List[Dict[str, object]] | None = None
157
- self._last_loaded: float = 0.0
158
- self._lock = asyncio.Lock()
 
 
 
159
 
160
  def _collection(self) -> AsyncIOMotorCollection:
161
  return self._collection_getter()
@@ -163,72 +48,220 @@ class AutoCategoryService:
163
  def _subcategory_collection(self) -> AsyncIOMotorCollection:
164
  return self._subcategory_collection_getter()
165
 
166
- async def categorize(self, notes: str) -> CategoryPrediction:
 
 
 
 
 
167
  try:
168
- categories = await asyncio.wait_for(
169
- self._get_categories(), timeout=self._categories_timeout_seconds
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  )
171
  except asyncio.TimeoutError as exc:
172
  raise HTTPException(status_code=504, detail="Timed out loading categories from database.") from exc
173
  except Exception as exc:
174
  raise HTTPException(status_code=502, detail="Failed to load categories from database.") from exc
175
 
176
- if not categories:
177
- raise HTTPException(status_code=500, detail="No categories configured.")
178
 
179
- formatted_categories = self._format_categories(categories)
180
- user_prompt = (
 
 
 
181
  "Transaction note:\n"
182
  f"{notes}\n\n"
183
- "Available categories and subcategories:\n"
184
  f"{formatted_categories}\n\n"
185
- "Respond with the exact title and subcategory from the list above."
186
  )
187
 
188
- request_payload = dict(
189
  model=self._model,
190
- input=[
191
  {
192
  "role": "system",
193
  "content": (
194
  "You classify financial transactions into the closest category. "
195
- "Only use the provided title and subcategory options. "
196
- "Output valid JSON with keys 'title' and 'subcategory'."
197
  ),
198
  },
199
- {"role": "user", "content": [{"type": "input_text", "text": user_prompt}]},
200
  ],
201
  )
202
 
203
  try:
204
- response = await asyncio.wait_for(
205
- self._create_model_response(request_payload),
206
- timeout=self._model_timeout_seconds,
207
- )
208
- except TypeError as exc:
209
- # Older openai-python clients (pre 1.3x) do not yet support response_format.
210
- if "response_format" not in str(exc):
211
- raise
212
- response = await asyncio.wait_for(
213
- self._openai_client.responses.create(**request_payload),
214
  timeout=self._model_timeout_seconds,
215
  )
216
  except asyncio.TimeoutError as exc:
217
- raise HTTPException(status_code=504, detail="Timed out waiting for model response.") from exc
218
  except Exception as exc:
219
- raise HTTPException(status_code=502, detail="Failed to call the model API.") from exc
 
 
 
 
220
 
221
  try:
222
- payload = self._parse_response_payload(response)
223
  except ValueError as exc:
224
- raise HTTPException(status_code=502, detail="Failed to parse model output.") from exc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
- title = payload.get("title")
227
- subcategory = payload.get("subcategory")
228
- if not isinstance(title, str) or not isinstance(subcategory, str):
229
- raise HTTPException(status_code=502, detail="Model response missing category fields.")
230
 
231
- return CategoryPrediction(title=title.strip(), subcategory=subcategory.strip())
 
 
 
 
 
232
 
233
  def _parse_response_payload(self, response) -> Dict[str, object]:
234
  raw_text = self._extract_response_text(response)
@@ -252,6 +285,14 @@ class AutoCategoryService:
252
 
253
  @staticmethod
254
  def _extract_response_text(response) -> str:
 
 
 
 
 
 
 
 
255
  text = getattr(response, "output_text", "") or ""
256
  if isinstance(text, str) and text.strip():
257
  return text.strip()
@@ -328,74 +369,221 @@ class AutoCategoryService:
328
 
329
  return None
330
 
331
- async def _get_categories(self) -> List[Dict[str, object]]:
332
- async with self._lock:
 
333
  now = time.monotonic()
334
- if self._cached_categories and (now - self._last_loaded) < self._cache_ttl_seconds:
335
- return self._cached_categories
336
-
337
- if settings.use_static_categories:
338
- self._cached_categories = self._STATIC_CATEGORIES
339
- self._last_loaded = now
340
- return self._cached_categories
341
-
342
- # Use headcategories + categories to avoid scanning millions of raw transaction titles.
343
- head_collection = self._collection()
344
- subcategory_collection = self._subcategory_collection()
345
-
346
- pipeline = [
347
- {"$match": {"type": "EXPENSE", "categories": {"$type": "array", "$ne": []}}},
348
- {"$group": {"_id": "$title", "category_ids": {"$first": "$categories"}}},
349
- ]
350
- head_docs = await head_collection.aggregate(pipeline).to_list(length=1000)
351
-
352
- all_ids: set[ObjectId] = set()
353
- for doc in head_docs:
354
- for cid in doc.get("category_ids") or []:
355
- if isinstance(cid, ObjectId):
356
- all_ids.add(cid)
357
-
358
- subcategory_titles: Dict[ObjectId, str] = {}
359
- if all_ids:
360
- cursor = subcategory_collection.find({"_id": {"$in": list(all_ids)}}, {"title": 1})
361
- async for subdoc in cursor:
362
- title = subdoc.get("title")
363
- if isinstance(title, str) and title.strip():
364
- subcategory_titles[subdoc["_id"]] = title.strip()
365
-
366
- categories: List[Dict[str, object]] = []
367
- for doc in head_docs:
368
- raw_title = doc.get("_id")
369
- if not isinstance(raw_title, str) or not raw_title.strip():
370
- continue
371
-
372
- ids = [cid for cid in (doc.get("category_ids") or []) if isinstance(cid, ObjectId)]
373
- subcategories = sorted({subcategory_titles[cid] for cid in ids if cid in subcategory_titles})
374
- if not subcategories:
375
- continue
376
-
377
- categories.append(
378
- {
379
- "title": raw_title.strip(),
380
- "subcategories": subcategories,
381
- }
382
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
 
384
- self._cached_categories = categories
385
- self._last_loaded = now
386
- return categories
387
 
388
  async def _create_model_response(self, request_payload: Dict[str, object]):
389
- return await self._openai_client.responses.create(
390
- response_format={"type": "json_object"},
391
- **request_payload,
392
- )
 
 
 
 
 
 
 
 
 
 
393
 
394
  @staticmethod
395
- def _format_categories(categories: List[Dict[str, object]]) -> str:
 
396
  lines = []
397
  for category in categories:
398
  subs = category.get("subcategories") or []
399
- subs_text = ", ".join(subs) if subs else "Unspecified"
400
  lines.append(f"- {category.get('title', 'Unknown')}: {subs_text}")
401
  return "\n".join(lines)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import asyncio
5
  import json
6
  import re
7
+ import string
8
  import time
9
+ from typing import Callable, Dict, List, Optional, Tuple
10
 
11
  from bson import ObjectId
12
  from fastapi import HTTPException
 
20
  class AutoCategoryService:
21
  """Classifies transaction notes into the closest Mongo-backed category."""
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def __init__(
24
  self,
25
  collection_getter: Callable[[], AsyncIOMotorCollection],
 
27
  openai_client: AsyncOpenAI,
28
  model: str,
29
  cache_ttl_seconds: int,
30
+ db_timeout_seconds: float,
31
+ model_timeout_seconds: float,
32
  ) -> None:
33
  self._collection_getter = collection_getter
34
  self._subcategory_collection_getter = subcategory_collection_getter
35
  self._openai_client = openai_client
36
  self._model = model
37
  self._cache_ttl_seconds = cache_ttl_seconds
38
+ self._db_timeout_seconds = db_timeout_seconds
39
+ self._model_timeout_seconds = model_timeout_seconds
40
+
41
+ # User-specific cache for headcategories: {user_id: (data, timestamp)}
42
+ self._headcategories_cache: Dict[str, Tuple[Dict[str, object], float]] = {}
43
+ self._cache_lock = asyncio.Lock()
44
 
45
  def _collection(self) -> AsyncIOMotorCollection:
46
  return self._collection_getter()
 
48
  def _subcategory_collection(self) -> AsyncIOMotorCollection:
49
  return self._subcategory_collection_getter()
50
 
51
+ async def categorize(self, notes: str, user_id: str) -> CategoryPrediction:
52
+ """Categorize transaction notes using a two-step approach:
53
+ 1. First match notes to a headcategory title
54
+ 2. Then match notes to a category within that headcategory
55
+ """
56
+ # Step 1: Fetch all headcategories for the user (with caching)
57
  try:
58
+ headcategories_data = await asyncio.wait_for(
59
+ self._get_headcategories_cached(user_id), timeout=self._db_timeout_seconds
60
+ )
61
+ except asyncio.TimeoutError as exc:
62
+ raise HTTPException(status_code=504, detail="Timed out loading headcategories from database.") from exc
63
+ except Exception as exc:
64
+ raise HTTPException(status_code=502, detail="Failed to load headcategories from database.") from exc
65
+
66
+ if not headcategories_data or not headcategories_data.get("headcategories"):
67
+ raise HTTPException(status_code=500, detail="No headcategories configured for this user.")
68
+
69
+ # Step 2: Use LLM to match notes to a headcategory title
70
+ headcategory_titles = [hc.get("title", "") for hc in headcategories_data["headcategories"]]
71
+ formatted_headcategories = "\n".join([f"- {title}" for title in headcategory_titles if title])
72
+
73
+ headcategory_prompt = (
74
+ "Transaction note:\n"
75
+ f"{notes}\n\n"
76
+ "Available headcategories:\n"
77
+ f"{formatted_headcategories}\n\n"
78
+ "Respond with the exact headcategory title from the list above that best matches this transaction."
79
+ )
80
+
81
+ headcategory_request = dict(
82
+ model=self._model,
83
+ messages=[
84
+ {
85
+ "role": "system",
86
+ "content": (
87
+ "You classify financial transactions into the closest headcategory. "
88
+ "Only use the provided headcategory title options. "
89
+ "Output valid JSON with key 'title'."
90
+ ),
91
+ },
92
+ {"role": "user", "content": headcategory_prompt},
93
+ ],
94
+ )
95
+
96
+ try:
97
+ headcategory_response = await asyncio.wait_for(
98
+ self._create_model_response(headcategory_request),
99
+ timeout=self._model_timeout_seconds,
100
+ )
101
+ except asyncio.TimeoutError as exc:
102
+ raise HTTPException(status_code=504, detail="Timed out waiting for headcategory model response.") from exc
103
+ except Exception as exc:
104
+ error_msg = str(exc)
105
+ raise HTTPException(
106
+ status_code=502,
107
+ detail=f"Failed to call the model API for headcategory: {error_msg}"
108
+ ) from exc
109
+
110
+ try:
111
+ headcategory_payload = self._parse_response_payload(headcategory_response)
112
+ except ValueError as exc:
113
+ raise HTTPException(status_code=502, detail="Failed to parse headcategory model output.") from exc
114
+
115
+ matched_headcategory_title = headcategory_payload.get("title")
116
+ if not isinstance(matched_headcategory_title, str):
117
+ raise HTTPException(status_code=502, detail="Model response missing headcategory title field.")
118
+
119
+ # Step 3: Find the matched headcategory and get its categories (optimized lookup)
120
+ matched_headcategory = None
121
+ matched_title_normalized = self._normalize_string(matched_headcategory_title)
122
+ matched_title_lower = matched_headcategory_title.lower()
123
+
124
+ # Try exact normalized match first (most common case)
125
+ for hc in headcategories_data["headcategories"]:
126
+ hc_title = hc.get("title", "")
127
+ if self._normalize_string(hc_title) == matched_title_normalized:
128
+ matched_headcategory = hc
129
+ break
130
+
131
+ # Try partial matching if exact normalized match fails
132
+ if not matched_headcategory:
133
+ for hc in headcategories_data["headcategories"]:
134
+ hc_title = hc.get("title", "").lower()
135
+ if matched_title_lower in hc_title or hc_title in matched_title_lower:
136
+ matched_headcategory = hc
137
+ break
138
+
139
+ if not matched_headcategory:
140
+ available_titles = ", ".join(headcategory_titles[:10])
141
+ raise HTTPException(
142
+ status_code=502,
143
+ detail=(
144
+ f"Could not find matching headcategory for title: '{matched_headcategory_title}'. "
145
+ f"Available headcategories: {available_titles}"
146
+ )
147
+ )
148
+
149
+ headcategory_id = matched_headcategory.get("_id")
150
+ category_ids = matched_headcategory.get("category_ids", [])
151
+
152
+ if not isinstance(headcategory_id, ObjectId):
153
+ raise HTTPException(status_code=500, detail="Invalid headcategory ID format.")
154
+
155
+ if not category_ids:
156
+ raise HTTPException(status_code=500, detail="Selected headcategory has no categories.")
157
+
158
+ # Step 4: Fetch categories from categories collection
159
+ try:
160
+ categories_data = await asyncio.wait_for(
161
+ self._get_categories_by_ids(category_ids), timeout=self._db_timeout_seconds
162
  )
163
  except asyncio.TimeoutError as exc:
164
  raise HTTPException(status_code=504, detail="Timed out loading categories from database.") from exc
165
  except Exception as exc:
166
  raise HTTPException(status_code=502, detail="Failed to load categories from database.") from exc
167
 
168
+ if not categories_data or not categories_data.get("categories"):
169
+ raise HTTPException(status_code=500, detail="No categories found for the selected headcategory.")
170
 
171
+ # Step 5: Use LLM to match notes to a specific category
172
+ category_titles = [cat.get("title", "") for cat in categories_data["categories"]]
173
+ formatted_categories = "\n".join([f"- {title}" for title in category_titles if title])
174
+
175
+ category_prompt = (
176
  "Transaction note:\n"
177
  f"{notes}\n\n"
178
+ "Available categories:\n"
179
  f"{formatted_categories}\n\n"
180
+ "Respond with the exact category title from the list above that best matches this transaction."
181
  )
182
 
183
+ category_request = dict(
184
  model=self._model,
185
+ messages=[
186
  {
187
  "role": "system",
188
  "content": (
189
  "You classify financial transactions into the closest category. "
190
+ "Only use the provided category title options. "
191
+ "Output valid JSON with key 'title'."
192
  ),
193
  },
194
+ {"role": "user", "content": category_prompt},
195
  ],
196
  )
197
 
198
  try:
199
+ category_response = await asyncio.wait_for(
200
+ self._create_model_response(category_request),
 
 
 
 
 
 
 
 
201
  timeout=self._model_timeout_seconds,
202
  )
203
  except asyncio.TimeoutError as exc:
204
+ raise HTTPException(status_code=504, detail="Timed out waiting for category model response.") from exc
205
  except Exception as exc:
206
+ error_msg = str(exc)
207
+ raise HTTPException(
208
+ status_code=502,
209
+ detail=f"Failed to call the model API for category: {error_msg}"
210
+ ) from exc
211
 
212
  try:
213
+ category_payload = self._parse_response_payload(category_response)
214
  except ValueError as exc:
215
+ raise HTTPException(status_code=502, detail="Failed to parse category model output.") from exc
216
+
217
+ matched_category_title = category_payload.get("title")
218
+ if not isinstance(matched_category_title, str):
219
+ raise HTTPException(status_code=502, detail="Model response missing category title field.")
220
+
221
+ # Step 6: Find the matched category ID (optimized lookup)
222
+ matched_category = None
223
+ matched_cat_title_normalized = self._normalize_string(matched_category_title)
224
+ matched_cat_title_lower = matched_category_title.lower()
225
+
226
+ # Try exact normalized match first (most common case)
227
+ for cat in categories_data["categories"]:
228
+ cat_title = cat.get("title", "")
229
+ if self._normalize_string(cat_title) == matched_cat_title_normalized:
230
+ matched_category = cat
231
+ break
232
+
233
+ # Try partial matching if exact normalized match fails
234
+ if not matched_category:
235
+ for cat in categories_data["categories"]:
236
+ cat_title = cat.get("title", "").lower()
237
+ if matched_cat_title_lower in cat_title or cat_title in matched_cat_title_lower:
238
+ matched_category = cat
239
+ break
240
+
241
+ if not matched_category:
242
+ available_titles = ", ".join(category_titles[:10])
243
+ raise HTTPException(
244
+ status_code=502,
245
+ detail=(
246
+ f"Could not find matching category for title: '{matched_category_title}'. "
247
+ f"Available categories: {available_titles}"
248
+ )
249
+ )
250
+
251
+ category_id = matched_category.get("_id")
252
+ if not isinstance(category_id, ObjectId):
253
+ raise HTTPException(status_code=500, detail="Invalid category ID format.")
254
 
255
+ # Get titles from matched objects
256
+ headcategory_title = matched_headcategory.get("title", "")
257
+ category_title = matched_category.get("title", "")
 
258
 
259
+ return CategoryPrediction(
260
+ headcategory_id=str(headcategory_id),
261
+ headcategory_title=headcategory_title,
262
+ category_id=str(category_id),
263
+ category_title=category_title,
264
+ )
265
 
266
  def _parse_response_payload(self, response) -> Dict[str, object]:
267
  raw_text = self._extract_response_text(response)
 
285
 
286
  @staticmethod
287
  def _extract_response_text(response) -> str:
288
+ """Extract text from OpenAI API response (supports both Chat Completions and Responses API)."""
289
+ # Try standard Chat Completions API format first
290
+ if hasattr(response, "choices") and response.choices:
291
+ message = response.choices[0].message
292
+ if hasattr(message, "content") and message.content:
293
+ return message.content.strip()
294
+
295
+ # Try Responses API format
296
  text = getattr(response, "output_text", "") or ""
297
  if isinstance(text, str) and text.strip():
298
  return text.strip()
 
369
 
370
  return None
371
 
372
+ async def _get_headcategories_cached(self, user_id: str) -> Dict[str, object]:
373
+ """Fetch headcategories from MongoDB with user-specific caching."""
374
+ async with self._cache_lock:
375
  now = time.monotonic()
376
+ # Check cache
377
+ if user_id in self._headcategories_cache:
378
+ cached_data, cached_time = self._headcategories_cache[user_id]
379
+ if (now - cached_time) < self._cache_ttl_seconds:
380
+ return cached_data
381
+ # Cache expired, remove it
382
+ del self._headcategories_cache[user_id]
383
+
384
+ # Fetch from database
385
+ data = await self._get_headcategories(user_id)
386
+
387
+ # Update cache
388
+ async with self._cache_lock:
389
+ self._headcategories_cache[user_id] = (data, time.monotonic())
390
+
391
+ return data
392
+
393
+ async def _get_headcategories(self, user_id: str) -> Dict[str, object]:
394
+ """Fetch headcategories from MongoDB filtered by user_id."""
395
+ head_collection = self._collection()
396
+
397
+ # Convert user_id string to ObjectId
398
+ try:
399
+ user_object_id = ObjectId(user_id)
400
+ except Exception as exc:
401
+ raise HTTPException(status_code=400, detail=f"Invalid user_id format: {user_id}") from exc
402
+
403
+ # Query headcategories filtered by user_id - only fetch needed fields for performance
404
+ head_docs = await head_collection.find(
405
+ {"user": user_object_id, "categories": {"$type": "array", "$ne": []}},
406
+ {"_id": 1, "title": 1, "categories": 1} # Only fetch needed fields
407
+ ).to_list(length=1000)
408
+
409
+ if not head_docs:
410
+ return {"headcategories": []}
411
+
412
+ # Build headcategories structure
413
+ headcategories: List[Dict[str, object]] = []
414
+ for head_doc in head_docs:
415
+ head_id = head_doc.get("_id")
416
+ if not isinstance(head_id, ObjectId):
417
+ continue
418
+
419
+ category_ids = [cid for cid in (head_doc.get("categories") or []) if isinstance(cid, ObjectId)]
420
+ if not category_ids:
421
+ continue
422
+
423
+ headcategories.append({
424
+ "_id": head_id,
425
+ "title": head_doc.get("title", ""),
426
+ "category_ids": category_ids,
427
+ })
428
+
429
+ return {"headcategories": headcategories}
430
+
431
+ async def _get_categories_by_ids(self, category_ids: List[ObjectId]) -> Dict[str, object]:
432
+ """Fetch categories from MongoDB by their ObjectIds."""
433
+ subcategory_collection = self._subcategory_collection()
434
+
435
+ if not category_ids:
436
+ return {"categories": []}
437
+
438
+ # Query categories collection with the provided ObjectIds
439
+ categories: List[Dict[str, object]] = []
440
+ cursor = subcategory_collection.find(
441
+ {"_id": {"$in": category_ids}},
442
+ {"title": 1, "_id": 1}
443
+ )
444
+ async for cat_doc in cursor:
445
+ cat_id = cat_doc.get("_id")
446
+ if isinstance(cat_id, ObjectId):
447
+ categories.append({
448
+ "_id": cat_id,
449
+ "title": cat_doc.get("title", ""),
450
+ })
451
 
452
+ return {"categories": categories}
 
 
453
 
454
  async def _create_model_response(self, request_payload: Dict[str, object]):
455
+ """Create a model response using OpenAI Chat Completions API."""
456
+ try:
457
+ return await self._openai_client.chat.completions.create(
458
+ response_format={"type": "json_object"},
459
+ **request_payload,
460
+ )
461
+ except TypeError as exc:
462
+ # Fallback for older openai-python clients or custom API endpoints
463
+ if "responses" in dir(self._openai_client):
464
+ return await self._openai_client.responses.create(
465
+ response_format={"type": "json_object"},
466
+ **request_payload,
467
+ )
468
+ raise
469
 
470
  @staticmethod
471
+ def _format_categories_for_llm(categories: List[Dict[str, object]]) -> str:
472
+ """Format categories for LLM prompt."""
473
  lines = []
474
  for category in categories:
475
  subs = category.get("subcategories") or []
476
+ subs_text = ", ".join([sub.get("title", "") for sub in subs if isinstance(sub, dict)]) if subs else "Unspecified"
477
  lines.append(f"- {category.get('title', 'Unknown')}: {subs_text}")
478
  return "\n".join(lines)
479
+
480
+ @staticmethod
481
+ def _normalize_string(s: str) -> str:
482
+ """Normalize string by removing punctuation and extra spaces for better matching."""
483
+ # Remove punctuation and convert to lowercase
484
+ normalized = s.translate(str.maketrans('', '', string.punctuation)).lower().strip()
485
+ # Replace multiple spaces with single space
486
+ normalized = ' '.join(normalized.split())
487
+ return normalized
488
+
489
+ @staticmethod
490
+ def _find_matching_ids(
491
+ categories: List[Dict[str, object]],
492
+ title: str,
493
+ subcategory: str
494
+ ) -> tuple[ObjectId | None, ObjectId | None]:
495
+ """Find matching headcategory_id and category_id based on title and subcategory strings.
496
+
497
+ Uses flexible matching:
498
+ 1. Exact match (case-insensitive)
499
+ 2. Normalized match (removes punctuation)
500
+ 3. Partial match (one contains the other)
501
+ 4. Word-based match (checks if key words match)
502
+ """
503
+ title_lower = title.strip().lower()
504
+ subcategory_lower = subcategory.strip().lower()
505
+ title_normalized = AutoCategoryService._normalize_string(title)
506
+ subcategory_normalized = AutoCategoryService._normalize_string(subcategory)
507
+
508
+ # First pass: exact match
509
+ for category in categories:
510
+ head_title = category.get("title", "").strip().lower()
511
+ if head_title != title_lower:
512
+ continue
513
+
514
+ subcategories = category.get("subcategories", [])
515
+ for sub in subcategories:
516
+ if isinstance(sub, dict):
517
+ sub_title = sub.get("title", "").strip().lower()
518
+ if sub_title == subcategory_lower:
519
+ headcategory_id = category.get("headcategory_id")
520
+ category_id = sub.get("_id")
521
+ if isinstance(headcategory_id, ObjectId) and isinstance(category_id, ObjectId):
522
+ return headcategory_id, category_id
523
+
524
+ # Second pass: normalized match (removes punctuation, handles "Wage" vs "Wage, Invoices")
525
+ for category in categories:
526
+ head_title = category.get("title", "").strip().lower()
527
+ head_title_norm = AutoCategoryService._normalize_string(head_title)
528
+ if head_title_norm != title_normalized and title_normalized not in head_title_norm and head_title_norm not in title_normalized:
529
+ continue
530
+
531
+ subcategories = category.get("subcategories", [])
532
+ for sub in subcategories:
533
+ if isinstance(sub, dict):
534
+ sub_title = sub.get("title", "").strip().lower()
535
+ sub_title_norm = AutoCategoryService._normalize_string(sub_title)
536
+ if (sub_title_norm == subcategory_normalized or
537
+ subcategory_normalized in sub_title_norm or
538
+ sub_title_norm in subcategory_normalized):
539
+ headcategory_id = category.get("headcategory_id")
540
+ category_id = sub.get("_id")
541
+ if isinstance(headcategory_id, ObjectId) and isinstance(category_id, ObjectId):
542
+ return headcategory_id, category_id
543
+
544
+ # Third pass: partial match (one contains the other)
545
+ for category in categories:
546
+ head_title = category.get("title", "").strip().lower()
547
+ # Check if title matches (exact or contains)
548
+ if title_lower not in head_title and head_title not in title_lower:
549
+ continue
550
+
551
+ subcategories = category.get("subcategories", [])
552
+ for sub in subcategories:
553
+ if isinstance(sub, dict):
554
+ sub_title = sub.get("title", "").strip().lower()
555
+ # Check if subcategory matches (exact or contains)
556
+ if (subcategory_lower in sub_title or sub_title in subcategory_lower or
557
+ subcategory_lower.split()[0] in sub_title or sub_title.split()[0] in subcategory_lower):
558
+ headcategory_id = category.get("headcategory_id")
559
+ category_id = sub.get("_id")
560
+ if isinstance(headcategory_id, ObjectId) and isinstance(category_id, ObjectId):
561
+ return headcategory_id, category_id
562
+
563
+ # Fourth pass: word-based matching (for cases like "Wage" matching "Wage, Invoices")
564
+ title_words = set(title_lower.split())
565
+ subcategory_words = set(subcategory_lower.split())
566
+
567
+ for category in categories:
568
+ head_title = category.get("title", "").strip().lower()
569
+ head_title_words = set(head_title.split())
570
+
571
+ # Check if there's significant word overlap for title
572
+ if not title_words.intersection(head_title_words) and not head_title_words.intersection(title_words):
573
+ continue
574
+
575
+ subcategories = category.get("subcategories", [])
576
+ for sub in subcategories:
577
+ if isinstance(sub, dict):
578
+ sub_title = sub.get("title", "").strip().lower()
579
+ sub_title_words = set(sub_title.split())
580
+
581
+ # Check if there's significant word overlap for subcategory
582
+ if (subcategory_words.intersection(sub_title_words) or
583
+ sub_title_words.intersection(subcategory_words)):
584
+ headcategory_id = category.get("headcategory_id")
585
+ category_id = sub.get("_id")
586
+ if isinstance(headcategory_id, ObjectId) and isinstance(category_id, ObjectId):
587
+ return headcategory_id, category_id
588
+
589
+ return None, None