"""FastAPI service that parses free-text shopping queries into structured
search intents with a fine-tuned causal LM.

Flow: POST /intent -> chat-prompt the model -> extract the first JSON
object from the generation -> sanitize it -> validate as `SearchIntent`.
"""

import json
from typing import Any, Dict, List, Literal, Optional

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, ValidationError
from transformers import AutoModelForCausalLM, AutoTokenizer

# Fine-tuned weights; the tokenizer comes from the base instruct model.
MODEL_REPO = "aagzamov/search-query-parser"
TOKENIZER_REPO = "Qwen/Qwen2.5-1.5B-Instruct"
MAX_NEW_TOKENS = 192

SortType = Literal["relevance", "price_asc", "price_desc", "newest"]

# Kept in sync with `SortType` / `Filters`; used by `normalize` to drop
# anything the model hallucinates outside the schema.
_ALLOWED_SORTS = {"relevance", "price_asc", "price_desc", "newest"}
_ALLOWED_FILTER_KEYS = {
    "brand", "category", "color", "size_eu",
    "price_min", "price_max", "in_stock", "shipping",
}


class Filters(BaseModel):
    """Optional facet filters extracted from the query text."""

    brand: Optional[List[str]] = None
    category: Optional[List[str]] = None
    color: Optional[List[str]] = None
    size_eu: Optional[List[int]] = None
    price_min: Optional[float] = None
    price_max: Optional[float] = None
    in_stock: Optional[bool] = None
    shipping: Optional[List[str]] = None


class SearchIntent(BaseModel):
    """Final, validated search intent returned to the caller."""

    query: str
    filters: Filters = Field(default_factory=Filters)
    sort: SortType = "relevance"
    page: int = 1
    limit: int = 24


class IntentRequest(BaseModel):
    """Request body for POST /intent."""

    text: str
    page: int = 1
    limit: int = 24
    # When set, overrides whatever sort the model predicts.
    sort: Optional[SortType] = None


# NOTE: runtime prompt text — do not reword without re-evaluating the model.
SYSTEM = (
    "You convert a shopping search text into a JSON object.\n"
    "Return ONLY valid JSON. No markdown. No extra keys.\n"
    "Must include: query, filters, sort, page, limit.\n"
    "filters may include only: brand, category, color, size_eu, price_min, price_max, in_stock, shipping.\n"
    "sort must be one of: relevance, price_asc, price_desc, newest.\n"
)


def to_chat_messages(user_query: str) -> List[Dict[str, str]]:
    """Build the chat-template message list for one user query."""
    return [
        {"role": "system", "content": SYSTEM},
        {"role": "user", "content": f"Query: {user_query}\nReturn JSON now."},
    ]


def try_parse_json(text: str) -> Optional[Dict[str, Any]]:
    """Extract and parse the first ``{...}`` span in *text*.

    Tolerates the model wrapping the object in stray prose. Returns None
    when no brace pair is found or the span is not valid JSON.
    """
    s = text.strip()
    a = s.find("{")
    b = s.rfind("}")
    if a == -1 or b == -1 or b <= a:
        return None
    try:
        # Fix: parse failures raise ValueError (JSONDecodeError); the
        # original bare `except Exception` would also hide real bugs.
        return json.loads(s[a : b + 1])
    except ValueError:
        return None


def normalize(
    obj: Dict[str, Any],
    page: int,
    limit: int,
    sort_override: Optional[str],
) -> Dict[str, Any]:
    """Coerce the model's raw dict into the `SearchIntent` shape.

    - drops unknown filter keys
    - falls back to "relevance" for unknown sorts
    - clamps page >= 1 and 1 <= limit <= 100

    *sort_override* (from the request) wins over the model's prediction.
    """
    out: Dict[str, Any] = {"query": str(obj.get("query", "")).strip()}

    filters = obj.get("filters", {})
    if not isinstance(filters, dict):
        filters = {}
    out["filters"] = {k: v for k, v in filters.items() if k in _ALLOWED_FILTER_KEYS}

    sort = sort_override or obj.get("sort", "relevance")
    out["sort"] = sort if sort in _ALLOWED_SORTS else "relevance"

    out["page"] = max(1, int(page))
    out["limit"] = max(1, min(100, int(limit)))
    return out


app = FastAPI(title="Search Query Parser API", version="1.0.0")

# Populated once at startup by `load_model`.
tokenizer = None
model = None


def load_model() -> None:
    """Load the tokenizer and model (fp16 + device_map on GPU, fp32 on CPU)."""
    global tokenizer, model
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_REPO, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_REPO,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
    )
    model.eval()


@torch.inference_mode()
def generate(text: str) -> str:
    """Greedy-decode a response for *text* and return only the new tokens."""
    prompt = tokenizer.apply_chat_template(
        to_chat_messages(text), tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt")
    # Always move tensors to the model's device (a no-op on CPU); the
    # original only did this when CUDA was available, which is equivalent
    # but needlessly branchy.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    out = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        # Fix: silence the "pad_token_id not set" warning on models that
        # lack a dedicated pad token.
        pad_token_id=tokenizer.eos_token_id,
    )
    # Drop the prompt prefix; decode only the generated continuation.
    gen_ids = out[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(gen_ids, skip_special_tokens=True)


# NOTE(review): on_event is deprecated in recent FastAPI — migrate to a
# lifespan context manager when upgrading; kept here for compatibility.
@app.on_event("startup")
def startup() -> None:
    """Load model weights before serving traffic."""
    load_model()


@app.get("/health")
def health() -> Dict[str, Any]:
    """Liveness probe plus basic device info."""
    cuda = torch.cuda.is_available()
    return {
        "ok": True,
        "model_repo": MODEL_REPO,
        "cuda": cuda,
        "device": torch.cuda.get_device_name(0) if cuda else None,
    }


@app.post("/intent", response_model=SearchIntent)
def intent(req: IntentRequest) -> SearchIntent:
    """Parse free text into a validated `SearchIntent`.

    Raises HTTP 422 when the model output is not JSON (`invalid_json`) or
    fails schema validation (`schema_validation_failed`).
    """
    raw = generate(req.text)
    parsed = try_parse_json(raw)
    if parsed is None:
        raise HTTPException(status_code=422, detail={"error": "invalid_json", "raw": raw})
    normalized = normalize(parsed, page=req.page, limit=req.limit, sort_override=req.sort)
    try:
        return SearchIntent.model_validate(normalized)
    except ValidationError as e:
        # Fix: catch only pydantic validation failures (not every
        # Exception) and chain the cause for debuggability.
        raise HTTPException(
            status_code=422,
            detail={
                "error": "schema_validation_failed",
                "normalized": normalized,
                "raw": raw,
                "msg": str(e),
            },
        ) from e