import json
from typing import Any, Dict, List, Literal, Optional

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_REPO = "aagzamov/search-query-parser"
TOKENIZER_REPO = "Qwen/Qwen2.5-1.5B-Instruct"
MAX_NEW_TOKENS = 192

SortType = Literal["relevance", "price_asc", "price_desc", "newest"]

class Filters(BaseModel):
    brand: Optional[List[str]] = None
    category: Optional[List[str]] = None
    color: Optional[List[str]] = None
    size_eu: Optional[List[int]] = None
    price_min: Optional[float] = None
    price_max: Optional[float] = None
    in_stock: Optional[bool] = None
    shipping: Optional[List[str]] = None

class SearchIntent(BaseModel):
    query: str
    filters: Filters = Field(default_factory=Filters)
    sort: SortType = "relevance"
    page: int = 1
    limit: int = 24

class IntentRequest(BaseModel):
    text: str
    page: int = 1
    limit: int = 24
    sort: Optional[SortType] = None
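
# Illustrative only: for a request like {"text": "red nike sneakers under 100"},
# a well-formed SearchIntent response would look like this (the field values
# below are an assumed example, not actual model output):
#
#   {
#     "query": "red nike sneakers",
#     "filters": {"brand": ["nike"], "color": ["red"], "price_max": 100.0},
#     "sort": "relevance",
#     "page": 1,
#     "limit": 24
#   }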

SYSTEM = (
    "You convert a shopping search text into a JSON object.\n"
    "Return ONLY valid JSON. No markdown. No extra keys.\n"
    "Must include: query, filters, sort, page, limit.\n"
    "filters may include only: brand, category, color, size_eu, price_min, price_max, in_stock, shipping.\n"
    "sort must be one of: relevance, price_asc, price_desc, newest.\n"
)

def to_chat_messages(user_query: str):
    return [
        {"role": "system", "content": SYSTEM},
        {"role": "user", "content": f"Query: {user_query}\nReturn JSON now."},
    ]

def try_parse_json(text: str) -> Optional[Dict[str, Any]]:
    # Extract the outermost {...} span so stray tokens around the JSON
    # (e.g. leading chatter or a trailing end-of-turn marker) do not
    # break parsing.
    s = text.strip()
    a = s.find("{")
    b = s.rfind("}")
    if a == -1 or b == -1 or b <= a:
        return None
    chunk = s[a : b + 1]
    try:
        return json.loads(chunk)
    except Exception:
        return None
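
# Quick sanity check for try_parse_json (the wrapping chatter below is a
# made-up example, not real model output):
#
#   >>> try_parse_json('Sure! {"query": "red shoes", "sort": "newest"} Done.')
#   {'query': 'red shoes', 'sort': 'newest'}
#   >>> try_parse_json("no json here")  # returns None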

def normalize(obj: Dict[str, Any], page: int, limit: int, sort_override: Optional[str]) -> Dict[str, Any]:
    out: Dict[str, Any] = {}
    out["query"] = str(obj.get("query", "")).strip()
    # Drop any hallucinated filter keys; keep only the whitelisted ones.
    filters = obj.get("filters", {})
    if not isinstance(filters, dict):
        filters = {}
    allowed = {"brand", "category", "color", "size_eu", "price_min", "price_max", "in_stock", "shipping"}
    out["filters"] = {k: v for k, v in filters.items() if k in allowed}
    # The caller's sort wins over whatever the model produced; unknown
    # values fall back to "relevance".
    sort = sort_override or obj.get("sort", "relevance")
    if sort not in {"relevance", "price_asc", "price_desc", "newest"}:
        sort = "relevance"
    out["sort"] = sort
    # Pagination comes from the request, clamped to sane bounds.
    out["page"] = max(1, int(page))
    out["limit"] = max(1, min(100, int(limit)))
    return out
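
# Illustrative normalization (assumed input; "material" stands in for any key
# outside the whitelist, so it is dropped, and "cheapest" is an invalid sort):
#
#   normalize({"query": "boots",
#              "filters": {"color": ["black"], "material": "leather"},
#              "sort": "cheapest"}, page=2, limit=500, sort_override=None)
#   -> {"query": "boots", "filters": {"color": ["black"]},
#       "sort": "relevance", "page": 2, "limit": 100}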

app = FastAPI(title="Search Query Parser API", version="1.0.0")

tokenizer = None
model = None

def load_model():
    global tokenizer, model
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_REPO, use_fast=False)
    # fp16 on GPU, fp32 on CPU; device_map="auto" lets accelerate place
    # the weights when a GPU is present.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_REPO,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
    )
    model.eval()

def generate(text: str) -> str:
    messages = to_chat_messages(text)
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
    # Greedy decoding (do_sample=False) keeps the JSON output deterministic
    # for a given input.
    out = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated completion, not the prompt tokens.
    gen_ids = out[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(gen_ids, skip_special_tokens=True)

@app.on_event("startup")
def startup():
    load_model()

@app.get("/health")
def health():
    return {
        "ok": True,
        "model_repo": MODEL_REPO,
        "cuda": torch.cuda.is_available(),
        "device": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
    }

@app.post("/intent", response_model=SearchIntent)
def intent(req: IntentRequest):
    raw = generate(req.text)
    parsed = try_parse_json(raw)
    if parsed is None:
        raise HTTPException(status_code=422, detail={"error": "invalid_json", "raw": raw})
    normalized = normalize(parsed, page=req.page, limit=req.limit, sort_override=req.sort)
    try:
        return SearchIntent.model_validate(normalized)
    except Exception as e:
        raise HTTPException(
            status_code=422,
            detail={
                "error": "schema_validation_failed",
                "normalized": normalized,
                "raw": raw,
                "msg": str(e),
            },
        )