# Search Query Parser API — FastAPI service wrapping a fine-tuned Qwen2.5
# causal LM. Deployed as a Docker-based Hugging Face Space.
import json
from typing import Any, Dict, List, Literal, Optional
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
# Fine-tuned model weights; tokenizer comes from the base model it was tuned from.
MODEL_REPO = "aagzamov/search-query-parser"
TOKENIZER_REPO = "Qwen/Qwen2.5-1.5B-Instruct"
# Generation budget — the structured JSON output is short, so 192 is ample.
MAX_NEW_TOKENS = 192
# Closed set of sort orders accepted by /intent and emitted by the model.
SortType = Literal["relevance", "price_asc", "price_desc", "newest"]
class Filters(BaseModel):
    """Optional facet filters extracted from a shopping query.

    Every field defaults to None (absent); the SYSTEM prompt restricts the
    model to exactly these keys.
    """

    brand: Optional[List[str]] = None
    category: Optional[List[str]] = None
    color: Optional[List[str]] = None
    size_eu: Optional[List[int]] = None  # EU shoe/clothing sizes
    price_min: Optional[float] = None
    price_max: Optional[float] = None
    in_stock: Optional[bool] = None
    shipping: Optional[List[str]] = None
class SearchIntent(BaseModel):
    """Structured search intent — the response schema of POST /intent."""

    query: str  # cleaned free-text search terms
    filters: Filters = Field(default_factory=Filters)
    sort: SortType = "relevance"
    page: int = 1
    limit: int = 24  # results per page; normalize() clamps to [1, 100]
class IntentRequest(BaseModel):
    """Request body of POST /intent."""

    text: str  # raw user search text to parse
    page: int = 1
    limit: int = 24
    sort: Optional[SortType] = None  # when set, overrides the model's sort choice
# System prompt constraining the model to emit only the SearchIntent JSON
# shape. Keep in sync with Filters / SortType above.
SYSTEM = (
    "You convert a shopping search text into a JSON object.\n"
    "Return ONLY valid JSON. No markdown. No extra keys.\n"
    "Must include: query, filters, sort, page, limit.\n"
    "filters may include only: brand, category, color, size_eu, price_min, price_max, in_stock, shipping.\n"
    "sort must be one of: relevance, price_asc, price_desc, newest.\n"
)
def to_chat_messages(user_query: str):
    """Build the two-turn chat prompt (system instructions + user query)."""
    user_turn = f"Query: {user_query}\nReturn JSON now."
    return [
        {"role": "system", "content": SYSTEM},
        {"role": "user", "content": user_turn},
    ]
def try_parse_json(text: str) -> Optional[Dict[str, Any]]:
    """Extract and parse the first JSON object embedded in *text*.

    Model output may wrap the JSON in prose. Scan to the first '{' and use
    json.JSONDecoder.raw_decode, which parses one object and tolerates any
    trailing text — including stray '}' characters that would break the old
    first-'{'/last-'}' slicing approach.

    Returns:
        The parsed dict, or None when no valid JSON object is present.
    """
    s = text.strip()
    start = s.find("{")
    if start == -1:
        return None
    try:
        obj, _end = json.JSONDecoder().raw_decode(s, start)
    except ValueError:
        # Fallback: widest brace span, matching the previous behavior for
        # inputs where raw_decode's strict single-object parse fails.
        end = s.rfind("}")
        if end <= start:
            return None
        try:
            obj = json.loads(s[start : end + 1])
        except ValueError:
            return None
    # raw_decode can legally return non-dict JSON; callers expect a mapping.
    return obj if isinstance(obj, dict) else None
def normalize(obj: Dict[str, Any], page: int, limit: int, sort_override: Optional[str]) -> Dict[str, Any]:
    """Coerce a raw model-produced dict into the SearchIntent shape.

    Strips the query, drops unknown filter keys, falls back to "relevance"
    for invalid sort values (honoring *sort_override* first), clamps page
    to >= 1 and limit to the range [1, 100].
    """
    raw_filters = obj.get("filters", {})
    if not isinstance(raw_filters, dict):
        raw_filters = {}
    allowed_keys = {
        "brand", "category", "color", "size_eu",
        "price_min", "price_max", "in_stock", "shipping",
    }
    chosen_sort = sort_override if sort_override else obj.get("sort", "relevance")
    if chosen_sort not in ("relevance", "price_asc", "price_desc", "newest"):
        chosen_sort = "relevance"
    return {
        "query": str(obj.get("query", "")).strip(),
        "filters": {key: val for key, val in raw_filters.items() if key in allowed_keys},
        "sort": chosen_sort,
        "page": max(1, int(page)),
        "limit": min(100, max(1, int(limit))),
    }
# FastAPI application; routes are registered via decorators below.
app = FastAPI(title="Search Query Parser API", version="1.0.0")
# Module-level singletons, populated once by load_model() at startup;
# None until then (generate() assumes startup has already run).
tokenizer = None
model = None
def load_model():
    """Load tokenizer and model from the Hub into the module globals.

    Uses fp16 + automatic device placement when a GPU is available,
    fp32 on CPU otherwise. Puts the model in eval mode.
    """
    global tokenizer, model
    has_cuda = torch.cuda.is_available()
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_REPO, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_REPO,
        torch_dtype=torch.float16 if has_cuda else torch.float32,
        device_map="auto" if has_cuda else None,
    )
    model.eval()
@torch.inference_mode()
def generate(text: str) -> str:
    """Greedy-decode the model's reply to *text* and return only the new tokens."""
    prompt = tokenizer.apply_chat_template(
        to_chat_messages(text), tokenize=False, add_generation_prompt=True
    )
    encoded = tokenizer(prompt, return_tensors="pt")
    if torch.cuda.is_available():
        # Model may be sharded by device_map="auto"; move inputs alongside it.
        encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    output_ids = model.generate(
        **encoded,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Slice off the prompt so only the generated continuation is decoded.
    prompt_len = encoded["input_ids"].shape[1]
    return tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
@app.on_event("startup")
def startup():
load_model()
@app.get("/health")
def health():
return {
"ok": True,
"model_repo": MODEL_REPO,
"cuda": torch.cuda.is_available(),
"device": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
}
@app.post("/intent", response_model=SearchIntent)
def intent(req: IntentRequest):
raw = generate(req.text)
parsed = try_parse_json(raw)
if parsed is None:
raise HTTPException(status_code=422, detail={"error": "invalid_json", "raw": raw})
normalized = normalize(parsed, page=req.page, limit=req.limit, sort_override=req.sort)
try:
return SearchIntent.model_validate(normalized)
except Exception as e:
raise HTTPException(status_code=422, detail={"error": "schema_validation_failed", "normalized": normalized, "raw": raw, "msg": str(e)})