aagzamov committed on
Commit
b1e9ae1
·
verified ·
1 Parent(s): 0da7288

Deploy FastAPI Swagger Space (Docker)

Browse files
Files changed (4) hide show
  1. Dockerfile +18 -0
  2. README.md +8 -5
  3. app.py +139 -0
  4. requirements.txt +7 -0
Dockerfile ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Image for the Search Query Parser FastAPI Space.
FROM python:3.11-slim

# No .pyc files in the image; unbuffered stdout so logs stream to the Space console.
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1

WORKDIR /app

# git is available for pip/transformers VCS resolution.
RUN apt-get update && apt-get install -y --no-install-recommends \
    git \
    && rm -rf /var/lib/apt/lists/*

# Install dependencies before copying the app so the pip layer is cached
# across code-only changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY app.py .

# Hugging Face Spaces serve Docker apps on port 7860.
EXPOSE 7860
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,11 +1,14 @@
1
  ---
2
  title: Search Query Parser
3
- emoji: 📊
4
- colorFrom: purple
5
- colorTo: pink
6
  sdk: docker
7
  pinned: false
8
- short_description: Search query parser
9
  ---
10
 
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
1
  ---
2
  title: Search Query Parser
3
+ emoji: 🔎
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: docker
7
  pinned: false
 
8
  ---
9
 
10
+ A FastAPI service that parses shopping search queries into structured JSON, with interactive Swagger documentation.
11
+
12
+ - `/docs` Swagger UI
13
+ - `/intent` parse query into JSON
14
+ - `/health` health check
app.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import Any, Dict, List, Literal, Optional
3
+
4
+ import torch
5
+ from fastapi import FastAPI, HTTPException
6
+ from pydantic import BaseModel, Field
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+
# Hugging Face Hub repo holding the fine-tuned query-parser causal LM.
MODEL_REPO = "aagzamov/search-query-parser"
# Cap on generated tokens per request; the expected JSON payload is small.
MAX_NEW_TOKENS = 192

# The only sort orders the API accepts and returns (see normalize()).
SortType = Literal["relevance", "price_asc", "price_desc", "newest"]
13
+
# Facet filters the model may extract from free text. Every field is
# optional (defaults to None) so partial filter sets validate cleanly.
class Filters(BaseModel):
    brand: Optional[List[str]] = None
    category: Optional[List[str]] = None
    color: Optional[List[str]] = None
    size_eu: Optional[List[int]] = None  # EU sizing, per the field name
    price_min: Optional[float] = None
    price_max: Optional[float] = None
    in_stock: Optional[bool] = None
    shipping: Optional[List[str]] = None
23
+
# Response schema for POST /intent: the fully normalized search request.
class SearchIntent(BaseModel):
    query: str  # cleaned free-text query
    filters: Filters = Field(default_factory=Filters)  # empty Filters when none extracted
    sort: SortType = "relevance"
    page: int = 1
    limit: int = 24
30
+
# Request body for POST /intent. page/limit/sort let the caller control
# pagination and override the model-predicted sort (see normalize()).
class IntentRequest(BaseModel):
    text: str  # raw user search text to parse
    page: int = 1
    limit: int = 24
    sort: Optional[SortType] = None  # when set, wins over the model's sort
36
+
# System prompt constraining the model to emit only the SearchIntent JSON.
SYSTEM = (
    "You convert a shopping search text into a JSON object.\n"
    "Return ONLY valid JSON. No markdown. No extra keys.\n"
    "Must include: query, filters, sort, page, limit.\n"
    "filters may include only: brand, category, color, size_eu, price_min, price_max, in_stock, shipping.\n"
    "sort must be one of: relevance, price_asc, price_desc, newest.\n"
)

def to_chat_messages(user_query: str):
    """Build the two-message chat (system + user) fed to the parser model."""
    system_msg = {"role": "system", "content": SYSTEM}
    user_msg = {"role": "user", "content": f"Query: {user_query}\nReturn JSON now."}
    return [system_msg, user_msg]
50
+
def try_parse_json(text: str) -> Optional[Dict[str, Any]]:
    """Extract and parse the outermost ``{...}`` span from model output.

    The model sometimes wraps its JSON in prose, so take the substring
    from the first ``{`` to the last ``}`` and try to decode it.

    Returns the parsed object, or None when no parseable JSON is present.
    """
    s = text.strip()
    start = s.find("{")
    end = s.rfind("}")
    # end == -1 is covered by end <= start whenever start >= 0.
    if start == -1 or end <= start:
        return None
    try:
        # A valid JSON document starting with '{' always decodes to a dict.
        return json.loads(s[start : end + 1])
    except ValueError:
        # json.JSONDecodeError is a ValueError subclass; catching only it
        # (rather than the original bare Exception) avoids masking
        # unrelated bugs such as TypeErrors.
        return None
62
+
def normalize(obj: Dict[str, Any], page: int, limit: int, sort_override: Optional[str]) -> Dict[str, Any]:
    """Coerce raw model output into the SearchIntent shape.

    Unknown filter keys are dropped, an invalid sort falls back to
    "relevance" (unless the caller supplied sort_override), page is
    clamped to >= 1 and limit to the range [1, 100].
    """
    raw_filters = obj.get("filters", {})
    if not isinstance(raw_filters, dict):
        raw_filters = {}

    allowed_keys = {"brand", "category", "color", "size_eu", "price_min", "price_max", "in_stock", "shipping"}
    kept_filters = {key: val for key, val in raw_filters.items() if key in allowed_keys}

    chosen_sort = sort_override or obj.get("sort", "relevance")
    if chosen_sort not in ("relevance", "price_asc", "price_desc", "newest"):
        chosen_sort = "relevance"

    return {
        "query": str(obj.get("query", "")).strip(),
        "filters": kept_filters,
        "sort": chosen_sort,
        "page": max(1, int(page)),
        # min-of-max is equivalent to max-of-min here: clamps into [1, 100].
        "limit": min(100, max(1, int(limit))),
    }
82
+
# FastAPI app; title/version surface on the generated /docs Swagger page.
app = FastAPI(title="Search Query Parser API", version="1.0.0")

# Populated by load_model() at startup. Kept None at import time so simply
# importing this module does not trigger a model download.
tokenizer = None
model = None
87
+
def load_model():
    """Download and initialize tokenizer + model from MODEL_REPO.

    Fills the module-level ``tokenizer`` and ``model`` globals; invoked
    once from the FastAPI startup hook.
    """
    global tokenizer, model
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_REPO,
        # fp16 on GPU for speed/memory; fp32 on CPU where fp16 is slow or unsupported.
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        # Let accelerate place layers automatically when a GPU is present;
        # None keeps the default CPU placement otherwise.
        device_map="auto" if torch.cuda.is_available() else None,
    )
    model.eval()  # inference only: disables dropout and similar train-time behavior
97
+
@torch.inference_mode()  # no autograd bookkeeping; saves memory during generation
def generate(text: str) -> str:
    """Run greedy decoding on the parser model and return only the new text.

    Requires load_model() to have populated the ``tokenizer`` and
    ``model`` globals first.
    """
    messages = to_chat_messages(text)
    # Render the chat into the model's prompt format with the
    # assistant-turn marker appended so generation starts the reply.
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt")
    if torch.cuda.is_available():
        # NOTE(review): assumes device_map placed the model on a single
        # device; with multi-GPU sharding model.device is the first shard —
        # confirm if multi-GPU is ever used.
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

    out = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,  # deterministic greedy decoding — JSON output should not vary
        eos_token_id=tokenizer.eos_token_id,
    )
    # Drop the prompt tokens; decode only what the model generated.
    gen_ids = out[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(gen_ids, skip_special_tokens=True)
114
+
@app.on_event("startup")
def startup():
    # Load weights once per process so requests never pay the download cost.
    # NOTE(review): on_event is deprecated in recent FastAPI in favor of
    # lifespan handlers; it still works on the pinned 0.115.x.
    load_model()
118
+
@app.get("/health")
def health():
    """Liveness probe reporting the model repo and GPU availability."""
    has_cuda = torch.cuda.is_available()
    device_name = torch.cuda.get_device_name(0) if has_cuda else None
    return {
        "ok": True,
        "model_repo": MODEL_REPO,
        "cuda": has_cuda,
        "device": device_name,
    }
127
+
@app.post("/intent", response_model=SearchIntent)
def intent(req: IntentRequest):
    """Parse free-form search text into a validated SearchIntent.

    Raises HTTP 422 when the model output contains no JSON or the
    normalized object still fails schema validation.
    """
    raw = generate(req.text)

    candidate = try_parse_json(raw)
    if candidate is None:
        raise HTTPException(
            status_code=422,
            detail={"error": "invalid_json", "raw": raw},
        )

    cleaned = normalize(candidate, page=req.page, limit=req.limit, sort_override=req.sort)
    try:
        return SearchIntent.model_validate(cleaned)
    except Exception as e:
        raise HTTPException(
            status_code=422,
            detail={
                "error": "schema_validation_failed",
                "normalized": cleaned,
                "raw": raw,
                "msg": str(e),
            },
        )
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
# Web framework + ASGI server
fastapi==0.115.8
uvicorn==0.34.0
pydantic==2.10.6
# Model inference stack (torch unpinned: wheel choice is platform-specific)
torch
transformers==4.48.2
safetensors
accelerate