mabosaimi committed on
Commit
e0a827b
·
1 Parent(s): 048f0fe

feat: add routes with proper documentation

Browse files
Files changed (5) hide show
  1. Dockerfile +2 -0
  2. app.py +90 -10
  3. requirements.txt +3 -2
  4. schemas.json +270 -0
  5. utils.py +115 -0
Dockerfile CHANGED
@@ -3,6 +3,8 @@ FROM python:3.12-slim
3
  RUN useradd -m -u 1000 user
4
  USER user
5
  ENV PATH="/home/user/.local/bin:$PATH"
 
 
6
 
7
  WORKDIR /app
8
 
 
3
  RUN useradd -m -u 1000 user
4
  USER user
5
  ENV PATH="/home/user/.local/bin:$PATH"
6
+ ENV TOKENIZERS_PARALLELISM=false \
7
+ HF_HUB_DISABLE_TELEMETRY=1
8
 
9
  WORKDIR /app
10
 
app.py CHANGED
@@ -1,16 +1,96 @@
1
- from fastapi import FastAPI
2
- from pydantic import BaseModel
3
- from sentence_transformers import SentenceTransformer
4
 
5
- app = FastAPI()
6
- model = SentenceTransformer("mabosaimi/bge-m3-text2tables")
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- class Query(BaseModel):
10
  query: str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
 
 
 
 
12
 
13
- @app.post("/embed")
14
- async def embed_query(data: Query):
15
- embedding = model.encode(data.query)
16
- return {"embedding": embedding.tolist()}
 
 
 
 
 
1
+ from typing import List, Optional
 
 
2
 
3
+ from fastapi import FastAPI, Query
4
+ from pydantic import BaseModel, Field
5
 
6
+ from utils import (
7
+ semantic_search,
8
+ get_schemas,
9
+ get_model_id,
10
+ get_corpus_size,
11
+ )
12
+
13
+
14
+ app = FastAPI(title="Semantic Table Search API", version="1.0.0")
15
+
16
+
17
class SearchRequest(BaseModel):
    """Request body accepted by the semantic search endpoint.

    Attributes:
    - query: Free-text search string; must contain at least one character.
    - limit: Optional cap on the number of results (1 to 50). When omitted,
      the ``limit`` query parameter of the endpoint is used instead. It lives
      in the body so JSON-only clients do not need query parameters.
    """

    query: str = Field(..., min_length=1)
    limit: Optional[int] = Field(default=None, ge=1, le=50)
29
+
30
+
31
class Match(BaseModel):
    """One ranked hit returned by a search."""

    score: float = Field(
        ...,
        description="Cosine similarity score (-1 to 1)",
    )
    text: str = Field(
        ...,
        description="Matched table metadata text",
    )
    index: int = Field(
        ...,
        description="Stable index of the matched corpus entry",
    )
37
+
38
+
39
class SearchResponse(BaseModel):
    """Response envelope for a search: the hits plus request metadata."""

    # Echo of the query text that was searched for.
    query: str
    # Ranked matches in the order produced by the search.
    results: List[Match]
    # Number of entries in ``results``.
    count: int
    # Result cap that was actually applied to this request.
    limit: int
46
+
47
+
48
@app.get("/health")
def health() -> dict:
    """Report that the service is up, plus light diagnostics.

    The payload carries a fixed ``"ok"`` status, the number of corpus
    entries, and the identifier of the embedding model in use.
    """

    payload = {"status": "ok"}
    payload["corpus_size"] = get_corpus_size()
    payload["model"] = get_model_id()
    return payload
61
+
62
+
63
@app.get("/schemas")
def schemas(
    include_columns: bool = Query(False, description="Include column metadata"),
) -> List[dict]:
    """Return the known table schemas.

    Parameters:
    - include_columns: If true, each entry is the full schema definition
      including columns; if false, only the table name and description are
      returned.
    """

    table_schemas = get_schemas(include_columns=include_columns)
    return table_schemas
75
+
76
+
77
@app.post("/search", response_model=SearchResponse)
def search(
    body: SearchRequest,
    limit: int = Query(5, ge=1, le=50, description="Max number of results"),
) -> SearchResponse:
    """Perform a semantic search over table metadata and return ranked matches.

    Clients provide a natural language query and receive the most relevant
    tables with similarity scores and stable corpus indices.

    Parameters:
    - body: JSON body carrying the query text and an optional ``limit``.
    - limit: Query-string result cap, used only when the body omits ``limit``.

    Returns:
    - A SearchResponse echoing the query, the matches, their count, and the
      limit that was applied.
    """

    # The body value wins when present. An explicit None check states the
    # intent directly instead of relying on the truthiness of an integer.
    effective_limit = body.limit if body.limit is not None else limit
    results = semantic_search(body.query, top_k=effective_limit)
    # Build typed Match models so a malformed result tuple fails here, at the
    # point of construction, rather than inside response validation.
    matches = [
        Match(score=score, text=text, index=idx) for score, text, idx in results
    ]
    return SearchResponse(
        query=body.query,
        results=matches,
        count=len(matches),
        limit=effective_limit,
    )
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
- fastapi==0.95.2
2
- uvicorn[standard]==0.23.1
 
3
  sentence-transformers==5.1.0
4
  transformers==4.56.1
5
  torch==2.8.0
 
1
+ fastapi
2
+ uvicorn[standard]
3
+ huggingface_hub
4
  sentence-transformers==5.1.0
5
  transformers==4.56.1
6
  torch==2.8.0
schemas.json ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "table": "customers",
4
+ "description": "People or organizations that purchase or may purchase goods/services.",
5
+ "columns": [
6
+ {"name": "customer_id", "type": "BIGINT", "description": "Surrogate primary key for the customer."},
7
+ {"name": "account_type", "type": "VARCHAR(20)", "description": "Type such as consumer or business."},
8
+ {"name": "first_name", "type": "VARCHAR(80)", "description": "Given name for contact purposes."},
9
+ {"name": "last_name", "type": "VARCHAR(80)", "description": "Family name for contact purposes."},
10
+ {"name": "email", "type": "VARCHAR(255)", "description": "Primary contact email; unique when present."},
11
+ {"name": "phone", "type": "VARCHAR(32)", "description": "Primary contact phone number."},
12
+ {"name": "address_id", "type": "BIGINT", "description": "FK to customer_addresses for canonical address."},
13
+ {"name": "country", "type": "CHAR(2)", "description": "ISO‑2 country code of primary address."},
14
+ {"name": "registered_at", "type": "TIMESTAMP", "description": "Timestamp when the customer registered."},
15
+ {"name": "status", "type": "VARCHAR(20)", "description": "Lifecycle state such as active, prospect, churned."}
16
+ ]
17
+ },
18
+ {
19
+ "table": "customer_addresses",
20
+ "description": "Normalized postal addresses for customers, suppliers, and locations.",
21
+ "columns": [
22
+ {"name": "address_id", "type": "BIGINT", "description": "Primary key for the address record."},
23
+ {"name": "line1", "type": "VARCHAR(120)", "description": "Street line 1."},
24
+ {"name": "line2", "type": "VARCHAR(120)", "description": "Street line 2 or unit/suite (nullable)."},
25
+ {"name": "city", "type": "VARCHAR(80)", "description": "City or locality."},
26
+ {"name": "state", "type": "VARCHAR(80)", "description": "Region/state/province."},
27
+ {"name": "postal_code", "type": "VARCHAR(20)", "description": "Postal/ZIP code."},
28
+ {"name": "country", "type": "CHAR(2)", "description": "ISO‑2 country code."},
29
+ {"name": "latitude", "type": "DECIMAL(9,6)", "description": "Latitude for geocoding (nullable)."},
30
+ {"name": "longitude", "type": "DECIMAL(9,6)", "description": "Longitude for geocoding (nullable)."},
31
+ {"name": "valid_from", "type": "DATE", "description": "Start date the address is valid from."},
32
+ {"name": "valid_to", "type": "DATE", "description": "End date the address is valid to (nullable)."}
33
+ ]
34
+ },
35
+ {
36
+ "table": "products",
37
+ "description": "Catalog of sellable items or services.",
38
+ "columns": [
39
+ {"name": "product_id", "type": "BIGINT", "description": "Primary key for the product."},
40
+ {"name": "sku", "type": "VARCHAR(64)", "description": "Stock keeping unit code; unique per product."},
41
+ {"name": "product_name", "type": "VARCHAR(160)", "description": "Marketing name of the product."},
42
+ {"name": "category_id", "type": "BIGINT", "description": "FK to categories for hierarchy grouping."},
43
+ {"name": "unit_price", "type": "DECIMAL(12,2)", "description": "List price in the default currency."},
44
+ {"name": "currency", "type": "CHAR(3)", "description": "ISO‑4217 currency code for unit_price."},
45
+ {"name": "active", "type": "BOOLEAN", "description": "If the product is available for sale."},
46
+ {"name": "created_at", "type": "TIMESTAMP", "description": "Record creation time."},
47
+ {"name": "updated_at", "type": "TIMESTAMP", "description": "Last update time."}
48
+ ]
49
+ },
50
+ {
51
+ "table": "categories",
52
+ "description": "Product category hierarchy for analytics and navigation.",
53
+ "columns": [
54
+ {"name": "category_id", "type": "BIGINT", "description": "Primary key for the category."},
55
+ {"name": "category_name", "type": "VARCHAR(120)", "description": "Human‑readable category label."},
56
+ {"name": "parent_category_id", "type": "BIGINT", "description": "Self‑reference to parent category (nullable)."},
57
+ {"name": "description", "type": "TEXT", "description": "Longer description for the category (nullable)."}
58
+ ]
59
+ },
60
+ {
61
+ "table": "orders",
62
+ "description": "Sales orders placed by customers; header level details.",
63
+ "columns": [
64
+ {"name": "order_id", "type": "BIGINT", "description": "Primary key for the order."},
65
+ {"name": "customer_id", "type": "BIGINT", "description": "FK to customers who placed the order."},
66
+ {"name": "order_date", "type": "TIMESTAMP", "description": "Time the order was submitted."},
67
+ {"name": "status", "type": "VARCHAR(20)", "description": "Order state such as pending, paid, shipped."},
68
+ {"name": "payment_method", "type": "VARCHAR(20)", "description": "Method like card, bank, wallet."},
69
+ {"name": "order_total", "type": "DECIMAL(12,2)", "description": "Total monetary value of the order."},
70
+ {"name": "currency", "type": "CHAR(3)", "description": "Currency for order_total."},
71
+ {"name": "shipping_address_id", "type": "BIGINT", "description": "FK to customer_addresses for shipping."}
72
+ ]
73
+ },
74
+ {
75
+ "table": "order_items",
76
+ "description": "Line‑level items linking orders to products.",
77
+ "columns": [
78
+ {"name": "order_item_id", "type": "BIGINT", "description": "Primary key for the line item."},
79
+ {"name": "order_id", "type": "BIGINT", "description": "FK to orders header."},
80
+ {"name": "product_id", "type": "BIGINT", "description": "FK to products catalog."},
81
+ {"name": "quantity", "type": "INT", "description": "Units ordered for this product."},
82
+ {"name": "unit_price", "type": "DECIMAL(12,2)", "description": "Unit price at time of sale."},
83
+ {"name": "discount", "type": "DECIMAL(5,2)", "description": "Percentage discount applied (0–100)."},
84
+ {"name": "line_total", "type": "DECIMAL(12,2)", "description": "Extended price after discount."}
85
+ ]
86
+ },
87
+ {
88
+ "table": "payments",
89
+ "description": "Customer payments applied to orders or invoices.",
90
+ "columns": [
91
+ {"name": "payment_id", "type": "BIGINT", "description": "Primary key for the payment."},
92
+ {"name": "order_id", "type": "BIGINT", "description": "FK to orders being paid (nullable if invoice_id used)."},
93
+ {"name": "invoice_id", "type": "BIGINT", "description": "FK to invoices (nullable if order_id used)."},
94
+ {"name": "payment_date", "type": "TIMESTAMP", "description": "Time the payment was captured."},
95
+ {"name": "amount", "type": "DECIMAL(12,2)", "description": "Amount received in payment."},
96
+ {"name": "currency", "type": "CHAR(3)", "description": "Currency for the amount."},
97
+ {"name": "method", "type": "VARCHAR(20)", "description": "Card, bank_transfer, wallet, cash, etc."},
98
+ {"name": "status", "type": "VARCHAR(20)", "description": "authorized, captured, refunded, failed."}
99
+ ]
100
+ },
101
+ {
102
+ "table": "invoices",
103
+ "description": "Billing documents issued to customers for accounting.",
104
+ "columns": [
105
+ {"name": "invoice_id", "type": "BIGINT", "description": "Primary key for the invoice."},
106
+ {"name": "customer_id", "type": "BIGINT", "description": "FK to customers billed."},
107
+ {"name": "invoice_date", "type": "DATE", "description": "Date the invoice was issued."},
108
+ {"name": "due_date", "type": "DATE", "description": "Payment due date per terms."},
109
+ {"name": "total_due", "type": "DECIMAL(12,2)", "description": "Total amount due on the invoice."},
110
+ {"name": "currency", "type": "CHAR(3)", "description": "Currency for total_due."},
111
+ {"name": "status", "type": "VARCHAR(20)", "description": "open, paid, overdue, cancelled."}
112
+ ]
113
+ },
114
+ {
115
+ "table": "subscriptions",
116
+ "description": "Recurring customer subscriptions for SaaS or services.",
117
+ "columns": [
118
+ {"name": "subscription_id", "type": "BIGINT", "description": "Primary key for the subscription."},
119
+ {"name": "customer_id", "type": "BIGINT", "description": "FK to customers holding the subscription."},
120
+ {"name": "plan_id", "type": "BIGINT", "description": "FK to subscription_plans."},
121
+ {"name": "start_date", "type": "DATE", "description": "Start date of the subscription."},
122
+ {"name": "end_date", "type": "DATE", "description": "End date or null if ongoing."},
123
+ {"name": "status", "type": "VARCHAR(20)", "description": "active, paused, cancelled, expired."},
124
+ {"name": "auto_renew", "type": "BOOLEAN", "description": "Whether the subscription auto‑renews."}
125
+ ]
126
+ },
127
+ {
128
+ "table": "subscription_plans",
129
+ "description": "Catalog of subscription plans and pricing tiers.",
130
+ "columns": [
131
+ {"name": "plan_id", "type": "BIGINT", "description": "Primary key for the plan."},
132
+ {"name": "plan_name", "type": "VARCHAR(80)", "description": "Marketing name for the plan."},
133
+ {"name": "billing_cycle", "type": "VARCHAR(20)", "description": "monthly, yearly, etc."},
134
+ {"name": "price", "type": "DECIMAL(12,2)", "description": "Price per billing cycle."},
135
+ {"name": "currency", "type": "CHAR(3)", "description": "Currency for price."},
136
+ {"name": "features", "type": "JSON", "description": "Feature flags or limits for the plan."}
137
+ ]
138
+ },
139
+ {
140
+ "table": "employees",
141
+ "description": "Company staff directory for HR and RBAC.",
142
+ "columns": [
143
+ {"name": "employee_id", "type": "BIGINT", "description": "Primary key for employee."},
144
+ {"name": "first_name", "type": "VARCHAR(80)", "description": "Given name."},
145
+ {"name": "last_name", "type": "VARCHAR(80)", "description": "Family name."},
146
+ {"name": "email", "type": "VARCHAR(255)", "description": "Work email address (unique)."},
147
+ {"name": "phone", "type": "VARCHAR(32)", "description": "Work phone number."},
148
+ {"name": "hire_date", "type": "DATE", "description": "Date employee joined."},
149
+ {"name": "job_title", "type": "VARCHAR(120)", "description": "Official job title."},
150
+ {"name": "department_id", "type": "BIGINT", "description": "FK to departments."},
151
+ {"name": "manager_id", "type": "BIGINT", "description": "Self‑FK to the manager employee_id (nullable)."}
152
+ ]
153
+ },
154
+ {
155
+ "table": "departments",
156
+ "description": "Organizational units for budgeting and reporting.",
157
+ "columns": [
158
+ {"name": "department_id", "type": "BIGINT", "description": "Primary key for the department."},
159
+ {"name": "department_name", "type": "VARCHAR(120)", "description": "Name of the department."},
160
+ {"name": "cost_center", "type": "VARCHAR(32)", "description": "Accounting cost center code."},
161
+ {"name": "manager_id", "type": "BIGINT", "description": "FK to employees who manage the department."}
162
+ ]
163
+ },
164
+ {
165
+ "table": "projects",
166
+ "description": "Portfolio of internal or client projects with budgets and timelines.",
167
+ "columns": [
168
+ {"name": "project_id", "type": "BIGINT", "description": "Primary key for the project."},
169
+ {"name": "project_name", "type": "VARCHAR(160)", "description": "Short name of the project."},
170
+ {"name": "sponsor_department_id", "type": "BIGINT", "description": "FK to departments sponsoring the work."},
171
+ {"name": "start_date", "type": "DATE", "description": "Planned or actual start date."},
172
+ {"name": "end_date", "type": "DATE", "description": "Planned or actual end date (nullable)."},
173
+ {"name": "budget", "type": "DECIMAL(14,2)", "description": "Approved budget for the project."},
174
+ {"name": "status", "type": "VARCHAR(20)", "description": "planned, active, on_hold, complete."}
175
+ ]
176
+ },
177
+ {
178
+ "table": "tasks",
179
+ "description": "Executable work items under projects.",
180
+ "columns": [
181
+ {"name": "task_id", "type": "BIGINT", "description": "Primary key for the task."},
182
+ {"name": "project_id", "type": "BIGINT", "description": "FK to projects."},
183
+ {"name": "task_name", "type": "VARCHAR(160)", "description": "Short description of the task."},
184
+ {"name": "assignee_employee_id", "type": "BIGINT", "description": "FK to employees assigned."},
185
+ {"name": "due_date", "type": "DATE", "description": "Target completion date."},
186
+ {"name": "status", "type": "VARCHAR(20)", "description": "todo, in_progress, blocked, done."},
187
+ {"name": "priority", "type": "VARCHAR(10)", "description": "low, medium, high, urgent."}
188
+ ]
189
+ },
190
+ {
191
+ "table": "support_tickets",
192
+ "description": "Customer support issues tracked by the service team.",
193
+ "columns": [
194
+ {"name": "ticket_id", "type": "BIGINT", "description": "Primary key for the ticket."},
195
+ {"name": "customer_id", "type": "BIGINT", "description": "FK to customers who opened the ticket."},
196
+ {"name": "subject", "type": "VARCHAR(160)", "description": "Short title summarizing the issue."},
197
+ {"name": "description", "type": "TEXT", "description": "Detailed problem description."},
198
+ {"name": "priority", "type": "VARCHAR(10)", "description": "low, medium, high, urgent."},
199
+ {"name": "status", "type": "VARCHAR(20)", "description": "open, pending, on_hold, resolved, closed."},
200
+ {"name": "opened_at", "type": "TIMESTAMP", "description": "When the ticket was created."},
201
+ {"name": "closed_at", "type": "TIMESTAMP", "description": "When the ticket was closed (nullable)."}
202
+ ]
203
+ },
204
+ {
205
+ "table": "web_sessions",
206
+ "description": "Website/app sessions for digital analytics.",
207
+ "columns": [
208
+ {"name": "session_id", "type": "VARCHAR(64)", "description": "Client session identifier."},
209
+ {"name": "visitor_id", "type": "VARCHAR(64)", "description": "Anonymous or known user identifier."},
210
+ {"name": "started_at", "type": "TIMESTAMP", "description": "Session start time."},
211
+ {"name": "ended_at", "type": "TIMESTAMP", "description": "Session end time (nullable)."},
212
+ {"name": "source", "type": "VARCHAR(40)", "description": "Traffic source/medium or campaign."},
213
+ {"name": "device", "type": "VARCHAR(40)", "description": "Device class such as mobile or desktop."},
214
+ {"name": "country", "type": "CHAR(2)", "description": "ISO‑2 country of the session."}
215
+ ]
216
+ },
217
+ {
218
+ "table": "marketing_campaigns",
219
+ "description": "Planned and active campaigns for acquisition and retention.",
220
+ "columns": [
221
+ {"name": "campaign_id", "type": "BIGINT", "description": "Primary key for the campaign."},
222
+ {"name": "campaign_name", "type": "VARCHAR(160)", "description": "Human‑readable campaign label."},
223
+ {"name": "channel", "type": "VARCHAR(40)", "description": "email, ads, social, affiliates, etc."},
224
+ {"name": "budget", "type": "DECIMAL(14,2)", "description": "Allocated spend for the campaign."},
225
+ {"name": "currency", "type": "CHAR(3)", "description": "Currency of budget."},
226
+ {"name": "start_date", "type": "DATE", "description": "Campaign start date."},
227
+ {"name": "end_date", "type": "DATE", "description": "Campaign end date (nullable)."}
228
+ ]
229
+ },
230
+ {
231
+ "table": "leads",
232
+ "description": "Prospective customers captured by marketing or sales.",
233
+ "columns": [
234
+ {"name": "lead_id", "type": "BIGINT", "description": "Primary key for the lead."},
235
+ {"name": "source", "type": "VARCHAR(40)", "description": "Origin of the lead such as web, event, referral."},
236
+ {"name": "first_name", "type": "VARCHAR(80)", "description": "Lead’s given name."},
237
+ {"name": "last_name", "type": "VARCHAR(80)", "description": "Lead’s family name."},
238
+ {"name": "email", "type": "VARCHAR(255)", "description": "Contact email address (nullable)."},
239
+ {"name": "company", "type": "VARCHAR(160)", "description": "Company name if B2B (nullable)."},
240
+ {"name": "status", "type": "VARCHAR(20)", "description": "new, qualified, unqualified, converted."},
241
+ {"name": "created_at", "type": "TIMESTAMP", "description": "Time the lead was created."}
242
+ ]
243
+ },
244
+ {
245
+ "table": "crm_interactions",
246
+ "description": "Logged emails, calls, and meetings with leads or customers.",
247
+ "columns": [
248
+ {"name": "interaction_id", "type": "BIGINT", "description": "Primary key for the interaction."},
249
+ {"name": "actor_employee_id", "type": "BIGINT", "description": "FK to employees who performed the interaction."},
250
+ {"name": "lead_id", "type": "BIGINT", "description": "FK to leads (nullable if customer_id used)."},
251
+ {"name": "customer_id", "type": "BIGINT", "description": "FK to customers (nullable if lead_id used)."},
252
+ {"name": "channel", "type": "VARCHAR(20)", "description": "email, call, meeting, chat, etc."},
253
+ {"name": "occurred_at", "type": "TIMESTAMP", "description": "When the interaction occurred."},
254
+ {"name": "notes", "type": "TEXT", "description": "Free‑form summary of the interaction."}
255
+ ]
256
+ },
257
+ {
258
+ "table": "inventory",
259
+ "description": "Current and reserved stock levels by product and location.",
260
+ "columns": [
261
+ {"name": "inventory_id", "type": "BIGINT", "description": "Primary key for the inventory record."},
262
+ {"name": "product_id", "type": "BIGINT", "description": "FK to products."},
263
+ {"name": "location_id", "type": "BIGINT", "description": "FK to locations (warehouse/store)."},
264
+ {"name": "on_hand_qty", "type": "INT", "description": "Physical units currently available."},
265
+ {"name": "reserved_qty", "type": "INT", "description": "Units reserved for open orders."},
266
+ {"name": "reorder_point", "type": "INT", "description": "Threshold to trigger replenishment."},
267
+ {"name": "last_restock_date", "type": "DATE", "description": "Date of last inbound stock."}
268
+ ]
269
+ }
270
+ ]
utils.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from typing import List, Tuple, Dict, Any
6
+
7
+ import torch
8
+ from huggingface_hub import hf_hub_download
9
+ from sentence_transformers import SentenceTransformer
10
+ from sentence_transformers.util import cos_sim
11
+
12
# Hugging Face repository id. The embedding model and the two corpus artifact
# files below are all fetched from this same repo.
_MODEL_ID = "mabosaimi/bge-m3-text2tables"

# Instantiated once at import time so all callers share a single model object.
model: SentenceTransformer = SentenceTransformer(_MODEL_ID)

# Corpus entries as plain strings; list positions are what search results
# report as ``index``.
_corpus_text_file = hf_hub_download(repo_id=_MODEL_ID, filename="corpus_texts.json")
with open(_corpus_text_file, "r", encoding="utf-8") as _f:
    corpus_texts: List[str] = json.load(_f)

# Precomputed embeddings for the corpus entries.
# presumably row i corresponds to corpus_texts[i] - verify against the
# artifact producer.
_corpus_emb_file = hf_hub_download(repo_id=_MODEL_ID, filename="corpus_embeddings.pt")
# weights_only=True is passed explicitly so the downloaded file is never
# unpickled as arbitrary Python objects, whatever the installed torch
# version uses as its default.
corpus_embeddings: torch.Tensor = torch.load(
    _corpus_emb_file, map_location="cpu", weights_only=True
)

# Optional local schema catalogue shipped next to this module.
_schemas_PATH = Path(__file__).parent / "schemas.json"
schemas: List[Dict[str, Any]]
try:
    with open(_schemas_PATH, "r", encoding="utf-8") as _cf:
        schemas = json.load(_cf)
except FileNotFoundError:
    # A missing file is tolerated: the service runs with no schemas.
    schemas = []
29
+
30
+
31
def get_model_id() -> str:
    """Return the identifier string of the embedding model in use.

    Exposes only the model id, so diagnostics endpoints can report basic
    service info without touching the model object itself.
    """

    model_id = _MODEL_ID
    return model_id
39
+
40
+
41
def get_corpus_size() -> int:
    """Return how many text entries the loaded corpus contains."""

    entry_count = len(corpus_texts)
    return entry_count
45
+
46
+
47
def preprocess_text(query: str) -> str:
    """Normalize a query string by trimming surrounding whitespace.

    Inputs:
    - query: Natural language string to clean up.

    Returns:
    - The string without leading or trailing whitespace; interior
      whitespace is left untouched.
    """
    cleaned = query.strip()
    return cleaned
57
+
58
+
59
def encode_text(query: str) -> torch.Tensor:
    """Encode a natural language query into an embedding tensor.

    Inputs:
    - query: Natural language string to be embedded. Surrounding whitespace
      is stripped before encoding.

    Returns:
    - A torch.Tensor holding the L2-normalized embedding of the query. For a
      single string input the encoder returns a 1-D vector of shape (D,),
      not a 1 x D matrix.
    """
    query = preprocess_text(query)
    return model.encode(query, convert_to_tensor=True, normalize_embeddings=True)
70
+
71
+
72
def semantic_search(query: str, top_k: int = 5) -> List[Tuple[float, str, int]]:
    """Compute semantic similarity between a query and the stored corpus.

    Inputs:
    - query: Natural language search string.
    - top_k: Maximum number of results to return. Values below 1 are raised
      to 1, and the value is capped at the number of searchable entries.

    Returns:
    - A list of tuples (score, text, index) sorted by descending similarity,
      where:
      - score is a float cosine similarity.
      - text is the matched corpus entry.
      - index is the integer position in the corpus (stable identifier).
      The list is empty when there are no searchable entries.
    """

    query_embedding = encode_text(query)
    scores = cos_sim(query_embedding, corpus_embeddings)[0]
    # Only rank positions that have both a text and an embedding. If the two
    # artifacts ever differ in length, this avoids an out-of-range k in topk
    # and an IndexError on the text lookup. With matching lengths it is a
    # no-op.
    scores = scores[: len(corpus_texts)]
    k = min(max(top_k, 1), scores.shape[0])
    values, indices = torch.topk(scores, k=k)
    # tolist() yields plain Python floats/ints, so no per-element tensor
    # indexing or conversion is needed.
    return [
        (score, corpus_texts[idx], idx)
        for score, idx in zip(values.tolist(), indices.tolist())
    ]
95
+
96
+
97
def get_schemas(include_columns: bool = False) -> List[Dict[str, Any]]:
    """Return the locally loaded table schemas.

    Inputs:
    - include_columns: True returns the schema entries as loaded, columns
      included; False returns a reduced view of each table.

    Returns:
    - A list of table dicts. In the reduced view each dict has exactly the
      keys "table" and "description" (description defaults to "" when the
      entry has none). An empty list is returned when no schemas are loaded.
    """

    if not schemas:
        return []
    if not include_columns:
        compact_view = []
        for entry in schemas:
            compact_view.append(
                {"table": entry["table"], "description": entry.get("description", "")}
            )
        return compact_view
    return schemas