jeanbaptdzd committed
Commit dc14519 · 1 Parent(s): 6f42b13

Refactor: Address code shortcomings and align with HF best practices


Phase 1 - Critical Fixes:
- Fix deprecated clear_gpu_memory() calls (remove model/tokenizer params)
- Register rate limiting middleware
- Add /v1/stats endpoint
- Improve thread safety with is_model_ready() (see the sketch after this list)
- Apply log_level from config dynamically
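A minimal sketch of the thread-safety idea behind `is_model_ready()`. The helper's exact body is not shown in the diff below, so this is an assumption built around the `_init_lock`, `_initialized`, `model`, and `tokenizer` globals that do appear in `transformers_provider.py`:

```python
from threading import Lock

# Module-level state, as in app/providers/transformers_provider.py
_init_lock = Lock()
_initialized = False
model = None
tokenizer = None

def is_model_ready() -> bool:
    """Thread-safe readiness check (hypothetical body).

    Reading the flag and globals under the lock avoids racing with a
    concurrent initialize_model() that is still assigning them.
    """
    with _init_lock:
        return _initialized and model is not None and tokenizer is not None
```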

Phase 2 - Remove Redundancies:
- Simplify memory management (remove redundant cleanup in inference paths)
- Remove manual HF token env var setting (HF Hub handles it)
- Remove manual chat template loading (auto-loaded in transformers 4.45.0+)
- Remove manual device management (device_map='auto' handles it; see the sketch after this list)
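For context, a sketch of the loading pattern these removals rely on. The repo id and dtype are placeholders (the real `MODEL_NAME` lives in `app.utils.constants`), not the project's actual settings:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "org/model-name"  # placeholder, not the project's real repo id

# device_map="auto" lets accelerate place model weights on the available
# GPU(s), falling back to CPU, so no manual .to(device) calls are needed.
# On transformers 4.45.0+ the chat template shipped with the model repo is
# picked up by AutoTokenizer automatically.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,  # assumed dtype
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Accelerate's dispatch hooks move the (CPU) input tensors to the right
# device during generate(), which is why manual placement was redundant.
inputs = tokenizer("Hello", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```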

Phase 3 - Code Quality:
- Centralize version management in app/__init__.py (see the sketch after this list)
- Refactor long functions with helper methods
- Simplify memory cleanup to single pass
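The version-centralization pattern, condensed from the `app/__init__.py` and `app/main.py` diffs below:

```python
from fastapi import FastAPI

# app/__init__.py - single source of truth for the API version
__version__ = "1.0.0"

# app/main.py - consumers import it instead of hard-coding "1.0.0":
#   from app import __version__
app = FastAPI(
    title="LLM Pro Finance API (Transformers)",
    description="OpenAI-compatible API for financial LLM inference",
    version=__version__,
)
```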

Phase 4 - Testing & Documentation:
- Rewrite unit tests to test actual provider logic (see the sketch after this list)
- Add test coverage for helper methods
- Update README with improvements and HF best practices alignment
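The rewritten tests drive the provider directly by patching its module-level globals rather than mocking httpx; a condensed example in the style of the test diff below:

```python
import pytest
from unittest.mock import MagicMock, patch

from app.providers.transformers_provider import chat

@pytest.mark.asyncio
async def test_chat_streaming_returns_async_iterator():
    """chat() runs against fully mocked globals - no GPU or real model needed."""
    payload = {"model": "test-model", "messages": [{"role": "user", "content": "hello"}]}

    mock_tokenizer = MagicMock()
    mock_tokenizer.apply_chat_template.return_value = "formatted prompt"

    with patch('app.providers.transformers_provider.model', MagicMock()), \
         patch('app.providers.transformers_provider.tokenizer', mock_tokenizer), \
         patch('app.providers.transformers_provider.is_model_ready', return_value=True):
        result = await chat(payload, stream=True)

    # Streaming responses are exposed as an async iterator of SSE chunks
    assert hasattr(result, '__aiter__')
```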

README.md CHANGED
@@ -136,6 +136,23 @@ response = client.chat.completions.create(
 - Development: L4x1 GPU (24GB VRAM)
 - Production: L40s GPU (48GB VRAM)
 
+## Recent Improvements
+
+### Code Quality & Hugging Face Best Practices Alignment
+
+This codebase has been optimized to align with Hugging Face inference best practices:
+
+- **Simplified Memory Management**: Removed redundant manual GPU memory cleanup - `device_map="auto"` handles this automatically
+- **Streamlined Token Management**: Hugging Face Hub now auto-detects tokens from environment variables
+- **Auto-Loading Chat Templates**: Leverages transformers 4.45.0+ automatic chat template loading
+- **Automatic Device Placement**: Removed manual device management - `device_map="auto"` handles GPU/CPU placement
+- **Improved Thread Safety**: Enhanced model access checks with thread-safe helpers
+- **Centralized Version Management**: Single source of truth for API version
+
+### Deprecated Functions
+
+- `clear_gpu_memory(model, tokenizer)` - Parameters deprecated, use `clear_gpu_memory()` without arguments
+
 ## Development
 
 ### Local Setup
app/__init__.py CHANGED
@@ -1,2 +1,3 @@
-# empty package marker
+"""LLM Pro Finance API package."""
 
+__version__ = "1.0.0"
app/main.py CHANGED
@@ -7,23 +7,29 @@ from typing import Dict
 from fastapi import FastAPI, status
 from fastapi.responses import JSONResponse
 
+from app import __version__
 from app.config import settings
 from app.middleware import api_key_guard
+from app.middleware.rate_limit import rate_limit_middleware
 from app.routers import openai_api
 
-# Configure logging
-logging.basicConfig(level=logging.INFO)
+# Configure logging with level from settings
+log_level = getattr(logging, settings.log_level.upper())
+logging.basicConfig(level=log_level)
 logger = logging.getLogger(__name__)
 
 app = FastAPI(
     title="LLM Pro Finance API (Transformers)",
     description="OpenAI-compatible API for financial LLM inference",
-    version="1.0.0"
+    version=__version__
 )
 
 # Mount routers
 app.include_router(openai_api.router, prefix="/v1")
 
+# Rate limiting middleware (applied first)
+app.middleware("http")(rate_limit_middleware)
+
 # Optional API key middleware
 app.middleware("http")(api_key_guard)
 
@@ -64,7 +70,7 @@ async def root() -> Dict[str, str]:
     return {
         "status": "ok",
         "service": "Qwen Open Finance R 8B Inference",
-        "version": "1.0.0",
+        "version": __version__,
         "model": settings.model,
         "backend": "Transformers"
     }
app/providers/transformers_provider.py CHANGED
@@ -7,7 +7,7 @@ import re
 from typing import Dict, Any, AsyncIterator, Union, List, Optional
 import asyncio
 from threading import Thread, Lock
-from huggingface_hub import login, hf_hub_download
+from huggingface_hub import login
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
 
 from app.utils.constants import (
@@ -40,7 +40,6 @@ logger = logging.getLogger(__name__)
 # Global model state
 model = None
 tokenizer = None
-device = "cuda" if torch.cuda.is_available() else "cpu"
 _init_lock = Lock()
 _initializing = False
 _initialized = False
@@ -84,7 +83,7 @@ def initialize_model(force_reload: bool = False):
     # Clear previous model if force reloading
     if force_reload and model is not None:
         log_info("Force reload requested, clearing existing model...", print_output=True)
-        clear_gpu_memory(model, tokenizer)
+        clear_gpu_memory()
         model = None
         tokenizer = None
         _initialized = False
@@ -105,18 +104,12 @@ def initialize_model(force_reload: bool = False):
         log_info(f"{token_source} found (length: {len(hf_token)})", print_output=True)
 
         # Authenticate with Hugging Face Hub
+        # login() automatically handles token precedence and environment variables
        try:
            login(token=hf_token, add_to_git_credential=False)
            log_info("Successfully authenticated with Hugging Face Hub", print_output=True)
        except Exception as e:
            log_warning(f"Failed to authenticate with HF Hub: {e}", print_output=True)
-
-        # Set token environment variables
-        os.environ.update({
-            "HF_TOKEN": hf_token,
-            "HUGGING_FACE_HUB_TOKEN": hf_token,
-            "HF_API_TOKEN": hf_token,
-        })
     else:
         log_warning(
             "No HF token found! Model download may fail if model is gated.",
@@ -124,6 +117,7 @@ def initialize_model(force_reload: bool = False):
         )
 
     # Load tokenizer
+    # Modern transformers (4.45.0+) auto-load chat templates from model repo
    log_info("Loading tokenizer...", print_output=True)
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
@@ -132,21 +126,9 @@ def initialize_model(force_reload: bool = False):
         cache_dir=CACHE_DIR,
     )
 
-    # Load custom chat template if missing
+    # Verify chat template is available (should be auto-loaded)
     if not hasattr(tokenizer, 'chat_template') or tokenizer.chat_template is None:
-        try:
-            template_path = hf_hub_download(
-                repo_id=MODEL_NAME,
-                filename="chat_template.jinja",
-                repo_type="model",
-                token=hf_token,
-                cache_dir=CACHE_DIR,
-            )
-            with open(template_path, 'r', encoding='utf-8') as f:
-                tokenizer.chat_template = f.read()
-            log_info("Custom chat template applied", print_output=True)
-        except Exception as e:
-            log_warning(f"Could not load custom template, using default: {e}")
+        log_warning("Chat template not found - will use fallback formatting")
 
     log_info("Tokenizer loaded", print_output=True)
 
@@ -178,7 +160,7 @@ def initialize_model(force_reload: bool = False):
         error_msg = f"Error initializing model: {e}"
         log_error(error_msg, exc_info=True, print_output=True)
 
-        clear_gpu_memory(model, tokenizer)
+        clear_gpu_memory()
         model = None
         tokenizer = None
 
@@ -222,8 +204,8 @@ class TransformersProvider:
     ) -> Union[Dict[str, Any], AsyncIterator[str]]:
         """Handle chat completion requests."""
         try:
-            # Initialize model on first use
-            if model is None:
+            # Initialize model on first use (thread-safe check)
+            if not is_model_ready():
                 log_info("Model not initialized, initializing now...")
                 initialize_model()
                 log_info("Model initialized successfully")
@@ -307,7 +289,8 @@ class TransformersProvider:
             log_warning("No chat_template found, using fallback")
 
         # Tokenize
-        inputs = tokenizer(prompt, return_tensors="pt").to(device)
+        # device_map="auto" handles device placement automatically
+        inputs = tokenizer(prompt, return_tensors="pt")
 
         # Handle streaming vs non-streaming
         if stream:
@@ -323,110 +306,99 @@ class TransformersProvider:
         self, inputs, temperature: float, top_p: float, max_tokens: int, model_id: str, tools: Optional[List[Dict[str, Any]]] = None, json_output_required: bool = False
     ) -> Dict[str, Any]:
         """Generate non-streaming response."""
-        try:
-            # Prepare generation kwargs
-            generation_kwargs = {
-                "max_new_tokens": max_tokens,
-                "temperature": temperature,
-                "top_p": top_p,
-                "top_k": DEFAULT_TOP_K,
-                "do_sample": temperature > 0,
-                "pad_token_id": PAD_TOKEN_ID,
-                "eos_token_id": EOS_TOKENS,
-                "repetition_penalty": REPETITION_PENALTY,
-                "early_stopping": False,
-                "use_cache": True,
-            }
-
-            # Note: Qwen reasoning models are designed to use reasoning tags
-            # We cannot completely disable reasoning, but we can:
-            # 1. Use strong prompts (already done above)
-            # 2. Post-process to extract desired output (done in _extract_json_from_text and _parse_tool_calls)
-            # 3. Set temperature to 0 for completely deterministic JSON output
-            # Temperature=0 uses greedy decoding (always picks most likely token)
-            # This maximizes consistency for structured outputs
-            if json_output_required:
-                # Set temperature to 0 for completely deterministic JSON output
-                # This uses greedy decoding which is ideal for structured formats
-                original_temp = generation_kwargs["temperature"]
-                generation_kwargs["temperature"] = 0.0
-                generation_kwargs["do_sample"] = False  # Explicitly set for temperature=0
-                log_info(f"Set temperature from {original_temp} to 0.0 (greedy decoding) for JSON output format")
-
-            with torch.no_grad():
-                outputs = model.generate(
-                    **inputs,
-                    **generation_kwargs,
-                )
-
-            # Extract token counts using tokenizer for accuracy
-            # Count prompt tokens (more accurate than shape[1] as it handles special tokens correctly)
-            prompt_tokens = len(inputs.input_ids[0])
-            generated_ids = outputs[0][inputs.input_ids.shape[1]:]
-            generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
-            completion_tokens = len(generated_ids)
-
-            # ✅ If JSON output is required, try to extract JSON from the response
-            if json_output_required:
-                generated_text = self._extract_json_from_text(generated_text)
-
-            # Parse tool calls from generated text
-            tool_calls = None
-            if tools:
-                tool_calls = self._parse_tool_calls(generated_text, tools)
-                if tool_calls:
-                    log_info(f"Parsed {len(tool_calls)} tool calls from response")
-                    # Remove tool call markers from content if present
-                    generated_text = self._clean_tool_calls_from_text(generated_text)
-
-            finish_reason = "tool_calls" if tool_calls else ("length" if completion_tokens >= max_tokens else "stop")
-
-            log_info(f"Generated {completion_tokens} tokens (max: {max_tokens}), finish: {finish_reason}")
-
-            # Record statistics
-            stats_tracker = get_stats_tracker()
-            stats_tracker.record_request(RequestStats(
-                timestamp=time.time(),
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=prompt_tokens + completion_tokens,
-                model=model_id,
-                finish_reason=finish_reason,
-            ))
-
-            # Build message with optional tool_calls
-            message = {"role": "assistant", "content": generated_text if generated_text.strip() else None}
-            if tool_calls:
-                message["tool_calls"] = tool_calls
-
-            return {
-                "id": f"chatcmpl-{os.urandom(12).hex()}",
-                "object": "chat.completion",
-                "created": int(time.time()),
-                "model": model_id,
-                "choices": [
-                    {
-                        "index": 0,
-                        "message": message,
-                        "finish_reason": finish_reason,
-                    }
-                ],
-                "usage": {
-                    "prompt_tokens": prompt_tokens,
-                    "completion_tokens": completion_tokens,
-                    "total_tokens": prompt_tokens + completion_tokens,
-                },
-            }
-        finally:
-            # Clean up GPU memory
-            if 'inputs' in locals():
-                del inputs
-            if 'outputs' in locals():
-                del outputs
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-            import gc
-            gc.collect()
+        # Prepare generation kwargs
+        generation_kwargs = {
+            "max_new_tokens": max_tokens,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": DEFAULT_TOP_K,
+            "do_sample": temperature > 0,
+            "pad_token_id": PAD_TOKEN_ID,
+            "eos_token_id": EOS_TOKENS,
+            "repetition_penalty": REPETITION_PENALTY,
+            "early_stopping": False,
+            "use_cache": True,
+        }
+
+        # Note: Qwen reasoning models are designed to use reasoning tags
+        # We cannot completely disable reasoning, but we can:
+        # 1. Use strong prompts (already done above)
+        # 2. Post-process to extract desired output (done in _extract_json_from_text and _parse_tool_calls)
+        # 3. Set temperature to 0 for completely deterministic JSON output
+        # Temperature=0 uses greedy decoding (always picks most likely token)
+        # This maximizes consistency for structured outputs
+        if json_output_required:
+            # Set temperature to 0 for completely deterministic JSON output
+            # This uses greedy decoding which is ideal for structured formats
+            original_temp = generation_kwargs["temperature"]
+            generation_kwargs["temperature"] = 0.0
+            generation_kwargs["do_sample"] = False  # Explicitly set for temperature=0
+            log_info(f"Set temperature from {original_temp} to 0.0 (greedy decoding) for JSON output format")
+
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                **generation_kwargs,
+            )
+
+        # Extract token counts using tokenizer for accuracy
+        # Count prompt tokens (more accurate than shape[1] as it handles special tokens correctly)
+        prompt_tokens = len(inputs.input_ids[0])
+        generated_ids = outputs[0][inputs.input_ids.shape[1]:]
+        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
+        completion_tokens = len(generated_ids)
+
+        # ✅ If JSON output is required, try to extract JSON from the response
+        if json_output_required:
+            generated_text = self._extract_json_from_text(generated_text)
+
+        # ✅ Parse tool calls from generated text
+        tool_calls = None
+        if tools:
+            tool_calls = self._parse_tool_calls(generated_text, tools)
+            if tool_calls:
+                log_info(f"Parsed {len(tool_calls)} tool calls from response")
+                # Remove tool call markers from content if present
+                generated_text = self._clean_tool_calls_from_text(generated_text)
+
+        finish_reason = "tool_calls" if tool_calls else ("length" if completion_tokens >= max_tokens else "stop")
+
+        log_info(f"Generated {completion_tokens} tokens (max: {max_tokens}), finish: {finish_reason}")
+
+        # Record statistics
+        stats_tracker = get_stats_tracker()
+        stats_tracker.record_request(RequestStats(
+            timestamp=time.time(),
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+            model=model_id,
+            finish_reason=finish_reason,
+        ))
+
+        # Build message with optional tool_calls
+        message = {"role": "assistant", "content": generated_text if generated_text.strip() else None}
+        if tool_calls:
+            message["tool_calls"] = tool_calls
+
+        return {
+            "id": f"chatcmpl-{os.urandom(12).hex()}",
+            "object": "chat.completion",
+            "created": int(time.time()),
+            "model": model_id,
+            "choices": [
+                {
+                    "index": 0,
+                    "message": message,
+                    "finish_reason": finish_reason,
+                }
+            ],
+            "usage": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": prompt_tokens + completion_tokens,
+            },
+        }
 
     async def _chat_stream(
         self, inputs, temperature: float, top_p: float, max_tokens: int, model_id: str, tools: Optional[List[Dict[str, Any]]] = None, json_output_required: bool = False
@@ -455,12 +427,8 @@ class TransformersProvider:
         }
 
         def generate():
-            try:
-                with torch.no_grad():
-                    model.generate(**inputs, **generation_kwargs)
-            finally:
-                if torch.cuda.is_available():
-                    torch.cuda.empty_cache()
+            with torch.no_grad():
+                model.generate(**inputs, **generation_kwargs)
 
         generation_thread = Thread(target=generate)
         generation_thread.start()
@@ -504,11 +472,6 @@ class TransformersProvider:
                 model=model_id,
                 finish_reason=finish_reason,
             ))
-
-            if 'inputs' in locals():
-                del inputs
-            import gc
-            gc.collect()
 
             # Send final chunk
             final_chunk = {
@@ -536,6 +499,51 @@ class TransformersProvider:
             prompt += "Assistant: "
         return prompt
 
+    def _remove_reasoning_tags(self, text: str) -> str:
+        """Remove Qwen reasoning tags from text."""
+        # Remove reasoning tags - matches <think>...</think>
+        cleaned_text = re.sub(
+            r'<think>.*?</think>',
+            '',
+            text,
+            flags=re.DOTALL | re.IGNORECASE
+        )
+
+        # Handle unclosed reasoning tags (split on closing tag)
+        if "</think>" in cleaned_text:
+            parts = cleaned_text.split("</think>", 1)
+            if len(parts) > 1:
+                cleaned_text = parts[1].strip()
+
+        # If still has opening tag but no closing, remove everything before first {
+        if "<think>" in cleaned_text.lower() and "{" in cleaned_text:
+            brace_pos = cleaned_text.find('{')
+            if brace_pos != -1:
+                cleaned_text = cleaned_text[brace_pos:]
+
+        return cleaned_text
+
+    def _extract_json_by_brace_matching(self, text: str, start_pos: int = 0) -> Optional[str]:
+        """Extract JSON object by matching braces starting at given position."""
+        brace_start = text.find('{', start_pos)
+        if brace_start == -1:
+            return None
+
+        brace_count = 0
+        for i in range(brace_start, len(text)):
+            if text[i] == '{':
+                brace_count += 1
+            elif text[i] == '}':
+                brace_count -= 1
+            if brace_count == 0:
+                json_candidate = text[brace_start:i+1]
+                try:
+                    json.loads(json_candidate)
+                    return json_candidate
+                except json.JSONDecodeError:
+                    return None
+        return None
+
     def _format_tools_for_prompt(self, tools: List[Dict[str, Any]]) -> str:
         """Format tools for inclusion in system prompt."""
         tools_text = (
@@ -580,18 +588,8 @@ class TransformersProvider:
         """Parse tool calls from generated text."""
         tool_calls = []
 
-        # First, remove reasoning tags to get clean text
-        cleaned_text = generated_text
-        cleaned_text = re.sub(
-            r'<think>.*?</think>',
-            '',
-            cleaned_text,
-            flags=re.DOTALL | re.IGNORECASE
-        )
-        if "</think>" in cleaned_text:
-            parts = cleaned_text.split("</think>", 1)
-            if len(parts) > 1:
-                cleaned_text = parts[1].strip()
+        # Remove reasoning tags to get clean text
+        cleaned_text = self._remove_reasoning_tags(generated_text)
 
         # Pattern to match <tool_call>...</tool_call> blocks
         pattern = r'<tool_call>\s*({.*?})\s*</tool_call>'
@@ -608,27 +606,22 @@ class TransformersProvider:
         if not matches:
             tool_names = [t.get("function", {}).get("name", "") for t in tools]
             # Look for JSON objects that might be tool calls
-            brace_start = cleaned_text.find('{')
-            while brace_start != -1:
-                # Try to extract JSON object starting at this position
-                brace_count = 0
-                for i in range(brace_start, len(cleaned_text)):
-                    if cleaned_text[i] == '{':
-                        brace_count += 1
-                    elif cleaned_text[i] == '}':
-                        brace_count -= 1
-                    if brace_count == 0:
-                        json_candidate = cleaned_text[brace_start:i+1]
-                        try:
-                            candidate_data = json.loads(json_candidate)
-                            if "name" in candidate_data and candidate_data["name"] in tool_names:
-                                matches.append(json_candidate)
-                                break
-                        except json.JSONDecodeError:
-                            pass
-                        break
+            brace_start = 0
+            while True:
+                json_candidate = self._extract_json_by_brace_matching(cleaned_text, brace_start)
+                if json_candidate is None:
+                    break
+                try:
+                    candidate_data = json.loads(json_candidate)
+                    if "name" in candidate_data and candidate_data["name"] in tool_names:
+                        matches.append(json_candidate)
+                        break
+                except json.JSONDecodeError:
+                    pass
                 # Find next {
-                brace_start = cleaned_text.find('{', brace_start + 1)
+                brace_start = cleaned_text.find('{', cleaned_text.find(json_candidate) + len(json_candidate))
+                if brace_start == -1:
+                    break
 
         for i, match in enumerate(matches):
             try:
@@ -676,30 +669,7 @@ class TransformersProvider:
     def _extract_json_from_text(self, text: str) -> str:
         """Extract JSON from text, handling cases where JSON is wrapped in markdown, reasoning tags, or other text."""
         # Step 1: Remove reasoning tags first (Qwen reasoning models)
-        # Handle <think> tags (Qwen reasoning format - actual tag is <think>)
-        cleaned_text = text
-
-        # Remove reasoning tags - matches <think>...</think>
-        cleaned_text = re.sub(
-            r'<think>.*?</think>',
-            '',
-            cleaned_text,
-            flags=re.DOTALL | re.IGNORECASE
-        )
-
-        # Also handle unclosed reasoning tags (split on closing tag)
-        if "</think>" in cleaned_text:
-            parts = cleaned_text.split("</think>", 1)
-            if len(parts) > 1:
-                cleaned_text = parts[1].strip()
-
-        # If still has opening tag but no closing, remove everything before first {
-        # This handles cases where reasoning tag is not closed but JSON follows
-        if "<think>" in cleaned_text.lower() and "{" in cleaned_text:
-            # Find first { and take everything from there
-            brace_pos = cleaned_text.find('{')
-            if brace_pos != -1:
-                cleaned_text = cleaned_text[brace_pos:]
+        cleaned_text = self._remove_reasoning_tags(text)
 
         # Step 2: Try to find JSON wrapped in markdown code blocks
         json_code_block = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', cleaned_text, re.DOTALL)
@@ -733,24 +703,10 @@ class TransformersProvider:
         if best_match:
             return best_match.strip()
 
-        # Step 4: Fallback - try to find any JSON-like structure
-        # Look for { ... } and try to extract it, even if nested
-        brace_start = cleaned_text.find('{')
-        if brace_start != -1:
-            # Find matching closing brace
-            brace_count = 0
-            for i in range(brace_start, len(cleaned_text)):
-                if cleaned_text[i] == '{':
-                    brace_count += 1
-                elif cleaned_text[i] == '}':
-                    brace_count -= 1
-                if brace_count == 0:
-                    json_candidate = cleaned_text[brace_start:i+1]
-                    try:
-                        json.loads(json_candidate)
-                        return json_candidate.strip()
-                    except json.JSONDecodeError:
-                        break
+        # Step 4: Fallback - try to find any JSON-like structure using brace matching
+        json_candidate = self._extract_json_by_brace_matching(cleaned_text)
+        if json_candidate:
+            return json_candidate.strip()
 
         # Step 5: If no JSON found, return cleaned text (without reasoning tags)
         # This allows the caller to handle it or show an error
app/routers/openai_api.py CHANGED
@@ -19,6 +19,17 @@ async def list_models_endpoint():
     return await list_models()
 
 
+@router.get("/stats")
+async def get_stats():
+    """Get API usage statistics.
+
+    Returns:
+        Dictionary containing request counts, token usage, and performance metrics.
+    """
+    from app.utils.stats import get_stats_tracker
+    return get_stats_tracker().get_stats()
+
+
 @router.post("/models/reload")
 async def reload_model(force: bool = Query(False, description="Force reload from Hugging Face Hub")):
     """
app/utils/memory.py CHANGED
@@ -41,14 +41,9 @@ def clear_gpu_memory(model: Optional[Any] = None, tokenizer: Optional[Any] = None
     if not torch.cuda.is_available():
         return
 
-    # Clear CUDA cache
+    # Clear CUDA cache and run garbage collection
+    # Single pass is sufficient with modern PyTorch and device_map="auto"
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    gc.collect()
-
-    # Force multiple garbage collection passes
-    for _ in range(3):
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
 
tests/test_providers.py CHANGED
@@ -1,51 +1,164 @@
+"""Tests for Transformers provider."""
+
 import pytest
-from unittest.mock import patch, AsyncMock
-import httpx
+from unittest.mock import patch, MagicMock, AsyncMock
+import torch
 
-from app.providers.transformers_provider import list_models, chat
+from app.providers.transformers_provider import list_models, chat, is_model_ready, TransformersProvider
 
 
 @pytest.mark.asyncio
 async def test_list_models_success():
     """Test successful model listing."""
-    mock_response = {"data": [{"id": "test-model"}]}
-
-    with patch('httpx.AsyncClient') as mock_client:
-        mock_response_obj = AsyncMock()
-        mock_response_obj.json.return_value = mock_response
-        mock_response_obj.raise_for_status.return_value = None
-
-        mock_client.return_value.__aenter__.return_value.get.return_value = mock_response_obj
-
-        result = await list_models()
-        assert result == mock_response
+    result = await list_models()
+
+    assert "object" in result
+    assert result["object"] == "list"
+    assert "data" in result
+    assert len(result["data"]) > 0
+    assert result["data"][0]["object"] == "model"
 
 
 @pytest.mark.asyncio
-async def test_chat_success():
-    """Test successful chat completion."""
-    payload = {"model": "test", "messages": [{"role": "user", "content": "hello"}]}
-    mock_response = {"choices": [{"message": {"content": "hi"}}]}
-
-    with patch('httpx.AsyncClient') as mock_client:
-        mock_response_obj = AsyncMock()
-        mock_response_obj.json.return_value = mock_response
-        mock_response_obj.raise_for_status.return_value = None
-
-        mock_client.return_value.__aenter__.return_value.post.return_value = mock_response_obj
+async def test_list_models_structure():
+    """Test model listing returns correct structure."""
+    result = await list_models()
+
+    model = result["data"][0]
+    assert "id" in model
+    assert "object" in model
+    assert "owned_by" in model
+    assert model["object"] == "model"
+
+
+@pytest.mark.asyncio
+async def test_chat_with_mock_model():
+    """Test chat completion with mocked model."""
+    payload = {
+        "model": "test-model",
+        "messages": [{"role": "user", "content": "hello"}],
+        "temperature": 0.7,
+        "max_tokens": 100
+    }
+
+    # Mock the global model and tokenizer
+    mock_tokenizer = MagicMock()
+    mock_tokenizer.apply_chat_template.return_value = "formatted prompt"
+    mock_tokenizer.encode.return_value = [1, 2, 3]
+    mock_tokenizer.decode.return_value = "test response"
+    mock_tokenizer.__call__.return_value = {
+        "input_ids": torch.tensor([[1, 2, 3]]),
+        "attention_mask": torch.tensor([[1, 1, 1]])
+    }
+
+    mock_model = MagicMock()
+    mock_outputs = MagicMock()
+    mock_outputs[0] = torch.tensor([[1, 2, 3, 4, 5]])
+    mock_model.generate.return_value = mock_outputs
+    mock_model.get_input_embeddings.return_value.num_embeddings = 1000
+
+    with patch('app.providers.transformers_provider.model', mock_model), \
+         patch('app.providers.transformers_provider.tokenizer', mock_tokenizer), \
+         patch('app.providers.transformers_provider.is_model_ready', return_value=True), \
+         patch('app.providers.transformers_provider._initialized', True):
 
         result = await chat(payload, stream=False)
-        assert result == mock_response
+
+        assert "id" in result
+        assert "object" in result
+        assert result["object"] == "chat.completion"
+        assert "choices" in result
+        assert len(result["choices"]) > 0
+        assert "usage" in result
 
 
 @pytest.mark.asyncio
-async def test_chat_stream():
+async def test_chat_streaming():
     """Test chat completion with streaming."""
-    payload = {"model": "test", "messages": [{"role": "user", "content": "hello"}]}
-    mock_stream = AsyncMock()
+    payload = {
+        "model": "test-model",
+        "messages": [{"role": "user", "content": "hello"}],
+        "stream": True
+    }
+
+    # Mock for streaming
+    mock_tokenizer = MagicMock()
+    mock_tokenizer.apply_chat_template.return_value = "formatted prompt"
+    mock_tokenizer.__call__.return_value = {
+        "input_ids": torch.tensor([[1, 2, 3]]),
+        "attention_mask": torch.tensor([[1, 1, 1]])
+    }
 
-    with patch('httpx.AsyncClient') as mock_client:
-        mock_client.return_value.__aenter__.return_value.stream.return_value = mock_stream
+    with patch('app.providers.transformers_provider.model', MagicMock()), \
+         patch('app.providers.transformers_provider.tokenizer', mock_tokenizer), \
+         patch('app.providers.transformers_provider.is_model_ready', return_value=True), \
+         patch('app.providers.transformers_provider._initialized', True):
 
         result = await chat(payload, stream=True)
-        assert result == mock_stream
+
+        # Should return an async iterator
+        assert hasattr(result, '__aiter__')
+
+
+def test_is_model_ready_false_when_not_initialized():
+    """Test is_model_ready returns False when model not initialized."""
+    with patch('app.providers.transformers_provider._initialized', False), \
+         patch('app.providers.transformers_provider.model', None), \
+         patch('app.providers.transformers_provider.tokenizer', None):
+
+        assert is_model_ready() is False
+
+
+def test_is_model_ready_true_when_initialized():
+    """Test is_model_ready returns True when model is initialized."""
+    mock_model = MagicMock()
+    mock_tokenizer = MagicMock()
+
+    with patch('app.providers.transformers_provider._initialized', True), \
+         patch('app.providers.transformers_provider.model', mock_model), \
+         patch('app.providers.transformers_provider.tokenizer', mock_tokenizer):
+
+        assert is_model_ready() is True
+
+
+def test_provider_format_tools_for_prompt():
+    """Test tool formatting for prompt."""
+    provider = TransformersProvider()
+    tools = [
+        {
+            "function": {
+                "name": "test_tool",
+                "description": "A test tool",
+                "parameters": {"type": "object", "properties": {}}
+            }
+        }
+    ]
+
+    result = provider._format_tools_for_prompt(tools)
+
+    assert "test_tool" in result
+    assert "CRITICAL" in result
+    assert "<tool_call>" in result
+
+
+def test_provider_remove_reasoning_tags():
+    """Test reasoning tag removal."""
+    provider = TransformersProvider()
+
+    text_with_tags = "<think>Some reasoning</think>Actual answer"
+    result = provider._remove_reasoning_tags(text_with_tags)
+
+    assert "<think>" not in result
+    assert "Actual answer" in result
+
+
+def test_provider_extract_json_by_brace_matching():
+    """Test JSON extraction by brace matching."""
+    provider = TransformersProvider()
+
+    text = "Some text {\"key\": \"value\"} more text"
+    result = provider._extract_json_by_brace_matching(text)
+
+    assert result is not None
+    assert "key" in result
+    assert "value" in result