Ashok75 committed on
Commit
6b59904
·
verified ·
1 Parent(s): b921a13

Upload server_runtime.py

Browse files
Files changed (1) hide show
  1. server_runtime.py +37 -11
server_runtime.py CHANGED
@@ -99,7 +99,16 @@ def _is_truthy(value: str) -> bool:
99
 
100
 
101
  def _format_sse_event(payload: Dict[str, Any]) -> str:
102
- return f"data: {json.dumps(payload)}\n\n"
 
 
 
 
 
 
 
 
 
103
 
104
 
105
  def _detect_concurrency(device: str) -> int:
@@ -122,7 +131,8 @@ def _detect_concurrency(device: str) -> int:
122
  return 3
123
 
124
  cpu_count = os.cpu_count() or 1
125
- return max(1, min(4, max(1, cpu_count // 2)))
 
126
 
127
 
128
  def create_hf_space_app(config: RuntimeConfig) -> FastAPI:
@@ -258,9 +268,9 @@ def create_hf_space_app(config: RuntimeConfig) -> FastAPI:
258
  break
259
 
260
  try:
261
- new_text = await asyncio.to_thread(next, stream_iter)
262
- except StopIteration:
263
- break
264
  except QueueEmpty:
265
  if generation_done.is_set():
266
  break
@@ -360,12 +370,28 @@ def create_hf_space_app(config: RuntimeConfig) -> FastAPI:
360
  if config.tokenizer_use_fast is not None:
361
  tokenizer_kwargs["use_fast"] = config.tokenizer_use_fast
362
  tokenizer = AutoTokenizer.from_pretrained(config.model_name, **tokenizer_kwargs)
363
- model = AutoModelForCausalLM.from_pretrained(
364
- config.model_name,
365
- trust_remote_code=True,
366
- torch_dtype="auto" if device == "cuda" else torch.float32,
367
- device_map="auto" if device == "cuda" else None,
368
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
 
370
  if device != "cuda":
371
  model = model.to("cpu")
 
99
 
100
 
101
  def _format_sse_event(payload: Dict[str, Any]) -> str:
102
+ event_type = str(payload.get("type", "token"))
103
+ return f"event: {event_type}\ndata: {json.dumps(payload)}\n\n"
104
+
105
+
106
+ def _read_stream_item(stream_iter) -> tuple[bool, Optional[str]]:
107
+ """Read one item from streamer iterator without leaking StopIteration across threads."""
108
+ try:
109
+ return False, next(stream_iter)
110
+ except StopIteration:
111
+ return True, None
112
 
113
 
114
  def _detect_concurrency(device: str) -> int:
 
131
  return 3
132
 
133
  cpu_count = os.cpu_count() or 1
134
+ # Conservative CPU default for large models; still within 1..4 range.
135
+ return max(1, min(4, max(1, cpu_count // 6)))
136
 
137
 
138
  def create_hf_space_app(config: RuntimeConfig) -> FastAPI:
 
268
  break
269
 
270
  try:
271
+ stream_finished, new_text = await asyncio.to_thread(_read_stream_item, stream_iter)
272
+ if stream_finished:
273
+ break
274
  except QueueEmpty:
275
  if generation_done.is_set():
276
  break
 
370
  if config.tokenizer_use_fast is not None:
371
  tokenizer_kwargs["use_fast"] = config.tokenizer_use_fast
372
  tokenizer = AutoTokenizer.from_pretrained(config.model_name, **tokenizer_kwargs)
373
+ model_load_kwargs: Dict[str, Any] = {
374
+ "trust_remote_code": True,
375
+ "device_map": "auto" if device == "cuda" else None,
376
+ }
377
+ if device == "cuda":
378
+ model_load_kwargs["dtype"] = "auto"
379
+ else:
380
+ model_load_kwargs["torch_dtype"] = torch.float32
381
+
382
+ try:
383
+ model = AutoModelForCausalLM.from_pretrained(
384
+ config.model_name,
385
+ **model_load_kwargs,
386
+ )
387
+ except TypeError:
388
+ # Backward compatibility for older transformers that do not accept `dtype`.
389
+ if "dtype" in model_load_kwargs:
390
+ model_load_kwargs["torch_dtype"] = model_load_kwargs.pop("dtype")
391
+ model = AutoModelForCausalLM.from_pretrained(
392
+ config.model_name,
393
+ **model_load_kwargs,
394
+ )
395
 
396
  if device != "cuda":
397
  model = model.to("cpu")