jeanbaptdzd commited on
Commit
64c014e
·
1 Parent(s): 7ee7723

Fix device placement for tokenizer outputs before model inference

Browse files

- Move tokenizer outputs to model device after tokenization (line 293)
- Ensure inputs are on model device before non-streaming generate() call (line 342)
- Ensure inputs are on model device before streaming generate() call (line 439)
- Fixes device mismatch issues when using device_map='auto'

app/providers/transformers_provider.py CHANGED
@@ -289,8 +289,11 @@ class TransformersProvider:
289
  log_warning("No chat_template found, using fallback")
290
 
291
  # Tokenize
292
- # device_map="auto" handles device placement automatically
293
  inputs = tokenizer(prompt, return_tensors="pt")
 
 
 
294
 
295
  # Handle streaming vs non-streaming
296
  if stream:
@@ -335,6 +338,10 @@ class TransformersProvider:
335
  generation_kwargs["do_sample"] = False # Explicitly set for temperature=0
336
  log_info(f"Set temperature from {original_temp} to 0.0 (greedy decoding) for JSON output format")
337
 
 
 
 
 
338
  with torch.no_grad():
339
  outputs = model.generate(
340
  **inputs,
@@ -427,8 +434,11 @@ class TransformersProvider:
427
  }
428
 
429
  def generate():
 
 
 
430
  with torch.no_grad():
431
- model.generate(**inputs, **generation_kwargs)
432
 
433
  generation_thread = Thread(target=generate)
434
  generation_thread.start()
@@ -523,27 +533,38 @@ class TransformersProvider:
523
 
524
  return cleaned_text
525
 
526
- def _extract_json_by_brace_matching(self, text: str, start_pos: int = 0) -> Optional[str]:
527
- """Extract JSON object by matching braces starting at given position."""
528
- brace_start = text.find('{', start_pos)
529
- if brace_start == -1:
530
- return None
531
-
532
- brace_count = 0
533
- for i in range(brace_start, len(text)):
534
- if text[i] == '{':
535
- brace_count += 1
536
- elif text[i] == '}':
537
- brace_count -= 1
538
- if brace_count == 0:
539
- json_candidate = text[brace_start:i+1]
540
- try:
541
- json.loads(json_candidate)
542
- return json_candidate
543
- except json.JSONDecodeError:
544
- return None
545
  return None
546
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
547
  def _format_tools_for_prompt(self, tools: List[Dict[str, Any]]) -> str:
548
  """Format tools for inclusion in system prompt."""
549
  tools_text = (
 
289
  log_warning("No chat_template found, using fallback")
290
 
291
  # Tokenize
292
+ # Move inputs to model device (device_map="auto" handles model placement, but inputs need explicit device placement)
293
  inputs = tokenizer(prompt, return_tensors="pt")
294
+ # Get model device (works with device_map="auto" by checking first parameter's device)
295
+ model_device = next(model.parameters()).device
296
+ inputs = {k: v.to(model_device) for k, v in inputs.items()}
297
 
298
  # Handle streaming vs non-streaming
299
  if stream:
 
338
  generation_kwargs["do_sample"] = False # Explicitly set for temperature=0
339
  log_info(f"Set temperature from {original_temp} to 0.0 (greedy decoding) for JSON output format")
340
 
341
+ # Ensure inputs are on model device before generation
342
+ model_device = next(model.parameters()).device
343
+ inputs = {k: v.to(model_device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
344
+
345
  with torch.no_grad():
346
  outputs = model.generate(
347
  **inputs,
 
434
  }
435
 
436
  def generate():
437
+ # Ensure inputs are on model device before generation
438
+ model_device = next(model.parameters()).device
439
+ inputs_on_device = {k: v.to(model_device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
440
  with torch.no_grad():
441
+ model.generate(**inputs_on_device, **generation_kwargs)
442
 
443
  generation_thread = Thread(target=generate)
444
  generation_thread.start()
 
533
 
534
  return cleaned_text
535
 
536
+ def _extract_json_by_brace_matching(self, text: str, start_pos: int = 0) -> Optional[str]:
537
+ """Extract JSON object by matching braces starting at given position."""
538
+ brace_start = text.find('{', start_pos)
539
+ if brace_start == -1:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
540
  return None
541
 
542
+ brace_count = 0
543
+ in_string = False
544
+ escape_next = False
545
+ for i in range(brace_start, len(text)):
546
+ if escape_next:
547
+ escape_next = False
548
+ continue
549
+ if text[i] == '\\':
550
+ escape_next = True
551
+ elif text[i] == '"' and not in_string:
552
+ in_string = True
553
+ elif text[i] == '"' and in_string:
554
+ in_string = False
555
+ elif text[i] == '{' and not in_string:
556
+ brace_count += 1
557
+ elif text[i] == '}' and not in_string:
558
+ brace_count -= 1
559
+ if brace_count == 0:
560
+ json_candidate = text[brace_start:i+1]
561
+ try:
562
+ json.loads(json_candidate)
563
+ return json_candidate
564
+ except json.JSONDecodeError:
565
+ return None
566
+ return None
567
+
568
  def _format_tools_for_prompt(self, tools: List[Dict[str, Any]]) -> str:
569
  """Format tools for inclusion in system prompt."""
570
  tools_text = (
app/routers/openai_api.py CHANGED
@@ -26,8 +26,18 @@ async def get_stats():
26
  Returns:
27
  Dictionary containing request counts, token usage, and performance metrics.
28
  """
29
- from app.utils.stats import get_stats_tracker
30
- return get_stats_tracker().get_stats()
 
 
 
 
 
 
 
 
 
 
31
 
32
 
33
  @router.post("/models/reload")
 
26
  Returns:
27
  Dictionary containing request counts, token usage, and performance metrics.
28
  """
29
+ try:
30
+ from app.utils.stats import get_stats_tracker
31
+ return get_stats_tracker().get_stats()
32
+ except Exception as e:
33
+ logger.error(f"Error getting stats: {str(e)}", exc_info=True)
34
+ return JSONResponse(
35
+ status_code=500,
36
+ content={
37
+ "status": "error",
38
+ "message": "Failed to retrieve statistics. Check logs for details.",
39
+ }
40
+ )
41
 
42
 
43
  @router.post("/models/reload")
tests/test_providers.py CHANGED
@@ -160,5 +160,4 @@ def test_provider_extract_json_by_brace_matching():
160
  result = provider._extract_json_by_brace_matching(text)
161
 
162
  assert result is not None
163
- assert "key" in result
164
- assert "value" in result
 
160
  result = provider._extract_json_by_brace_matching(text)
161
 
162
  assert result is not None
163
+ assert result.get("key") == "value"