celik-muhammed commited on
Commit
618ec49
·
verified ·
1 Parent(s): fd4644d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -12
app.py CHANGED
@@ -62,10 +62,10 @@ ALLOWED_MODELS: frozenset[str] = frozenset(
62
  )
63
 
64
  HTTP_TIMEOUT = httpx.Timeout(
65
- connect=10.0,
66
- read=120.0,
67
- write=30.0,
68
- pool=10.0,
69
  )
70
 
71
  # ---------------------------------------------------------------------------
@@ -80,9 +80,16 @@ app = FastAPI(
80
  redoc_url=None,
81
  )
82
 
 
 
 
 
 
 
 
83
  app.add_middleware(
84
  CORSMiddleware,
85
- allow_origins=["*"],
86
  allow_credentials=False,
87
  allow_methods=["GET", "POST", "OPTIONS"],
88
  allow_headers=["*"],
@@ -104,10 +111,14 @@ def _parse_json_body(body: bytes) -> dict[str, Any]:
104
  def _parse_model(body: bytes) -> str:
105
  """Extract model parameter and map obsolete legacy parameters cleanly."""
106
  data = _parse_json_body(body)
107
- model = str(data.get("model", "")).strip()
108
-
109
- if not model or model in ("openai/gpt-oss-20b", "scikit-plots/gpt-oss-20b"):
110
- return DEFAULT_MODEL
 
 
 
 
111
  return model
112
 
113
 
@@ -213,7 +224,10 @@ async def _forward(body: bytes, request: Request | None = None) -> Response:
213
  try:
214
  async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client:
215
  async with client.stream(
216
- "POST", url, content=updated_body, headers=headers
 
 
 
217
  ) as hf_resp:
218
 
219
  if hf_resp.status_code >= 400:
@@ -276,7 +290,11 @@ async def _forward(body: bytes, request: Request | None = None) -> Response:
276
  # Non-streaming processing block
277
  try:
278
  async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client:
279
- hf_resp = await client.post(url, content=updated_body, headers=headers)
 
 
 
 
280
  except httpx.HTTPError as exc:
281
  return _map_httpx_exception(exc)
282
 
@@ -291,7 +309,10 @@ async def _forward(body: bytes, request: Request | None = None) -> Response:
291
  return Response(
292
  content=hf_resp.content,
293
  status_code=hf_resp.status_code,
294
- media_type=hf_resp.headers.get("content-type", "application/json"),
 
 
 
295
  )
296
 
297
  # ---------------------------------------------------------------------------
 
62
  )
63
 
64
  HTTP_TIMEOUT = httpx.Timeout(
65
+ connect=float(os.environ.get("PROXY_CONNECT_TIMEOUT", "10")),
66
+ read=float(os.environ.get("PROXY_READ_TIMEOUT", "120")),
67
+ write=float(os.environ.get("PROXY_WRITE_TIMEOUT", "30")),
68
+ pool=float(os.environ.get("PROXY_POOL_TIMEOUT", "10")),
69
  )
70
 
71
  # ---------------------------------------------------------------------------
 
80
  redoc_url=None,
81
  )
82
 
83
+ # ["*"]
84
+ cors_origins = [
85
+ origin.strip()
86
+ for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
87
+ if origin.strip()
88
+ ]
89
+
90
  app.add_middleware(
91
  CORSMiddleware,
92
+ allow_origins=cors_origins,
93
  allow_credentials=False,
94
  allow_methods=["GET", "POST", "OPTIONS"],
95
  allow_headers=["*"],
 
111
  def _parse_model(body: bytes) -> str:
112
  """Extract model parameter and map obsolete legacy parameters cleanly."""
113
  data = _parse_json_body(body)
114
+ model = str(data.get("model", "")).strip() or DEFAULT_MODEL
115
+
116
+ if model not in ALLOWED_MODELS:
117
+ raise HTTPException(
118
+ status_code=400,
119
+ detail=f"Unsupported model: {model}",
120
+ )
121
+
122
  return model
123
 
124
 
 
224
  try:
225
  async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client:
226
  async with client.stream(
227
+ "POST",
228
+ url,
229
+ content=updated_body,
230
+ headers=headers,
231
  ) as hf_resp:
232
 
233
  if hf_resp.status_code >= 400:
 
290
  # Non-streaming processing block
291
  try:
292
  async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client:
293
+ hf_resp = await client.post(
294
+ url,
295
+ content=updated_body,
296
+ headers=headers,
297
+ )
298
  except httpx.HTTPError as exc:
299
  return _map_httpx_exception(exc)
300
 
 
309
  return Response(
310
  content=hf_resp.content,
311
  status_code=hf_resp.status_code,
312
+ media_type=hf_resp.headers.get(
313
+ "content-type",
314
+ "application/json",
315
+ ),
316
  )
317
 
318
  # ---------------------------------------------------------------------------