Commit 64c014e
Parent: 7ee7723

Fix device placement for tokenizer outputs before model inference

- Move tokenizer outputs to model device after tokenization (line 293)
- Ensure inputs are on model device before non-streaming generate() call (line 342)
- Ensure inputs are on model device before streaming generate() call (line 439)
- Fixes device mismatch issues when using device_map='auto'

Files changed:
- app/providers/transformers_provider.py +42 -21
- app/routers/openai_api.py +12 -2
- tests/test_providers.py +1 -2
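The core of the change is one small pattern: with device_map="auto" the model weights may sit on a GPU while the tokenizer always returns CPU tensors, so the inputs have to be moved to the model's device explicitly before generate(). A minimal sketch of that pattern outside this repo (the model name and prompt are placeholders, not taken from the project):

```python
# Minimal sketch of the device-placement pattern this commit applies.
# "gpt2" and the prompt are illustrative placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# tokenizer(...) returns CPU tensors regardless of where the model lives,
# so generate() can fail with a device mismatch unless we move them.
inputs = tokenizer("Hello, world", return_tensors="pt")

# Query the device of the first parameter; with device_map="auto" this is
# where the embedding layer that consumes input_ids was placed.
model_device = next(model.parameters()).device
inputs = {k: v.to(model_device) for k, v in inputs.items()}

with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```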
app/providers/transformers_provider.py
CHANGED

@@ -289,8 +289,11 @@ class TransformersProvider:
             log_warning("No chat_template found, using fallback")
 
         # Tokenize
-        # device_map="auto" handles device placement
+        # Move inputs to model device (device_map="auto" handles model placement, but inputs need explicit device placement)
         inputs = tokenizer(prompt, return_tensors="pt")
+        # Get model device (works with device_map="auto" by checking first parameter's device)
+        model_device = next(model.parameters()).device
+        inputs = {k: v.to(model_device) for k, v in inputs.items()}
 
         # Handle streaming vs non-streaming
         if stream:
@@ -335,6 +338,10 @@ class TransformersProvider:
             generation_kwargs["do_sample"] = False  # Explicitly set for temperature=0
             log_info(f"Set temperature from {original_temp} to 0.0 (greedy decoding) for JSON output format")
 
+        # Ensure inputs are on model device before generation
+        model_device = next(model.parameters()).device
+        inputs = {k: v.to(model_device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
+
         with torch.no_grad():
             outputs = model.generate(
                 **inputs,
@@ -427,8 +434,11 @@ class TransformersProvider:
         }
 
         def generate():
+            # Ensure inputs are on model device before generation
+            model_device = next(model.parameters()).device
+            inputs_on_device = {k: v.to(model_device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
             with torch.no_grad():
-                model.generate(**inputs, **generation_kwargs)
+                model.generate(**inputs_on_device, **generation_kwargs)
 
         generation_thread = Thread(target=generate)
         generation_thread.start()
@@ -523,27 +533,38 @@ class TransformersProvider:
 
         return cleaned_text
 
-    def _extract_json_by_brace_matching(self, text: str) -> Optional[str]:
-        """Extract JSON object by matching braces."""
-        brace_start = text.find('{')
-        if brace_start == -1:
-            return None
-
-        brace_count = 0
-        for i in range(brace_start, len(text)):
-            if text[i] == '{':
-                brace_count += 1
-            elif text[i] == '}':
-                brace_count -= 1
-                if brace_count == 0:
-                    json_candidate = text[brace_start:i+1]
-                    try:
-                        json.loads(json_candidate)
-                        return json_candidate
-                    except json.JSONDecodeError:
-                        return None
+    def _extract_json_by_brace_matching(self, text: str, start_pos: int = 0) -> Optional[str]:
+        """Extract JSON object by matching braces starting at given position."""
+        brace_start = text.find('{', start_pos)
+        if brace_start == -1:
+            return None
+
+        brace_count = 0
+        in_string = False
+        escape_next = False
+        for i in range(brace_start, len(text)):
+            if escape_next:
+                escape_next = False
+                continue
+            if text[i] == '\\':
+                escape_next = True
+            elif text[i] == '"' and not in_string:
+                in_string = True
+            elif text[i] == '"' and in_string:
+                in_string = False
+            elif text[i] == '{' and not in_string:
+                brace_count += 1
+            elif text[i] == '}' and not in_string:
+                brace_count -= 1
+                if brace_count == 0:
+                    json_candidate = text[brace_start:i+1]
+                    try:
+                        json.loads(json_candidate)
+                        return json_candidate
+                    except json.JSONDecodeError:
+                        return None
         return None
 
     def _format_tools_for_prompt(self, tools: List[Dict[str, Any]]) -> str:
         """Format tools for inclusion in system prompt."""
         tools_text = (
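The rewritten extractor tracks string and escape state so that braces inside JSON string values no longer unbalance the count. A standalone transcription of the committed logic, with a small check of the case a naive counter gets wrong (the sample text is made up):

```python
# Standalone transcription of the committed brace-matching logic, for
# illustration only; the sample text below is invented.
import json
from typing import Optional

def extract_json_by_brace_matching(text: str, start_pos: int = 0) -> Optional[str]:
    brace_start = text.find('{', start_pos)
    if brace_start == -1:
        return None
    brace_count = 0
    in_string = False
    escape_next = False
    for i in range(brace_start, len(text)):
        if escape_next:
            escape_next = False
            continue
        if text[i] == '\\':
            escape_next = True
        elif text[i] == '"' and not in_string:
            in_string = True
        elif text[i] == '"' and in_string:
            in_string = False
        elif text[i] == '{' and not in_string:
            brace_count += 1
        elif text[i] == '}' and not in_string:
            brace_count -= 1
            if brace_count == 0:
                candidate = text[brace_start:i + 1]
                try:
                    json.loads(candidate)
                    return candidate
                except json.JSONDecodeError:
                    return None
    return None

# The '}' inside the string value would make naive counting stop too early
# and return an unparseable fragment; the string-aware version succeeds.
sample = 'Model said: {"note": "use } sparingly", "ok": true} trailing text'
print(extract_json_by_brace_matching(sample))
# -> {"note": "use } sparingly", "ok": true}
```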
app/routers/openai_api.py
CHANGED

@@ -26,8 +26,18 @@ async def get_stats():
     Returns:
         Dictionary containing request counts, token usage, and performance metrics.
     """
-    from app.utils.stats import get_stats_tracker
-    return get_stats_tracker().get_stats()
+    try:
+        from app.utils.stats import get_stats_tracker
+        return get_stats_tracker().get_stats()
+    except Exception as e:
+        logger.error(f"Error getting stats: {str(e)}", exc_info=True)
+        return JSONResponse(
+            status_code=500,
+            content={
+                "status": "error",
+                "message": "Failed to retrieve statistics. Check logs for details.",
+            }
+        )
 
 
 @router.post("/models/reload")
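This change wraps the stats lookup so that a backend failure surfaces as a clean HTTP 500 payload instead of an unhandled exception. A self-contained sketch of the same pattern; the /stats route and failing_backend() are hypothetical stand-ins, not this repo's actual layout:

```python
# Sketch of the try/except + JSONResponse error-handling pattern.
# The /stats path and failing_backend() are hypothetical.
import logging

from fastapi import APIRouter, FastAPI
from fastapi.responses import JSONResponse

logger = logging.getLogger(__name__)
router = APIRouter()

def failing_backend() -> dict:
    raise RuntimeError("stats store unavailable")  # simulate a backend error

@router.get("/stats")
async def get_stats():
    try:
        return failing_backend()
    except Exception as e:
        logger.error(f"Error getting stats: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"status": "error", "message": "Failed to retrieve statistics."},
        )

app = FastAPI()
app.include_router(router)

# Exercising the error path with FastAPI's test client:
from fastapi.testclient import TestClient

client = TestClient(app)
response = client.get("/stats")
assert response.status_code == 500
assert response.json()["status"] == "error"
```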
tests/test_providers.py
CHANGED

@@ -160,5 +160,4 @@ def test_provider_extract_json_by_brace_matching():
     result = provider._extract_json_by_brace_matching(text)
 
     assert result is not None
-    assert "key" in result
-    assert "value" in result
+    assert result.get("key") == "value"
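One caveat against the provider diff above: _extract_json_by_brace_matching is annotated as returning Optional[str], so dict-style access on the result only works after parsing. A hedged variant of the assertion under that assumption:

```python
# Assumes result is the JSON string the extractor returns for text
# containing {"key": "value"}, per the Optional[str] annotation.
import json

result = '{"key": "value"}'
parsed = json.loads(result)
assert parsed.get("key") == "value"
```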