import os
import time
import json
import logging
import torch
import re
from typing import Dict, Any, AsyncIterator, Union, List, Optional
import asyncio
from threading import Thread, Lock
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList

from app.utils.constants import (
    MODEL_NAME,
    CACHE_DIR,
    FRENCH_SYSTEM_PROMPT,
    EOS_TOKENS,
    PAD_TOKEN_ID,
    DEFAULT_MAX_TOKENS,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_P,
    DEFAULT_TOP_K,
    REPETITION_PENALTY,
    MODEL_INIT_TIMEOUT_SECONDS,
    MODEL_INIT_WAIT_INTERVAL_SECONDS,
)
from app.utils.helpers import (
    get_hf_token,
    is_french_request,
    has_french_system_prompt,
    log_info,
    log_warning,
    log_error,
)
from app.utils.memory import clear_gpu_memory
from app.utils.stats import get_stats_tracker, RequestStats

logger = logging.getLogger(__name__)

# Global model state
model = None
tokenizer = None
_init_lock = Lock()
_initializing = False
_initialized = False


def initialize_model(force_reload: bool = False):
    """
    Initialize the Transformers model and tokenizer (Qwen3-based).
    
    Args:
        force_reload: If True, reload model even if already initialized.
    
    Thread-safe initialization with proper memory cleanup on failure.
    Handles authentication with Hugging Face Hub for accessing DragonLLM models.
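    
    Example (sketch; the first call actually downloads and loads the model, so
    it needs network access, an accepted model license, and enough VRAM):
    
        initialize_model()                    # loads MODEL_NAME once
        initialize_model()                    # no-op: already initialized
        initialize_model(force_reload=True)   # clears GPU memory and reloads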
    """
    global model, tokenizer, _initializing, _initialized
    
    # Check if already initialized (unless force reload)
    if not force_reload and _initialized and model is not None:
        return
    
    with _init_lock:
        # Double-check after acquiring lock
        if not force_reload and _initialized and model is not None:
            return
        
        # Handle concurrent initialization
        if _initializing:
            log_warning("Model initialization already in progress, waiting...")
            wait_count = 0
            while _initializing and wait_count < MODEL_INIT_TIMEOUT_SECONDS:
                time.sleep(MODEL_INIT_WAIT_INTERVAL_SECONDS)
                wait_count += 1
                if _initialized and model is not None:
                    return
            if wait_count >= MODEL_INIT_TIMEOUT_SECONDS:
                log_error("Model initialization timeout!", print_output=True)
                raise RuntimeError("Model initialization timed out")
            return
        
        # Clear previous model if force reloading
        if force_reload and model is not None:
            log_info("Force reload requested, clearing existing model...", print_output=True)
            clear_gpu_memory()
            model = None
            tokenizer = None
            _initialized = False
        
        # Clear any previous failed attempts
        if model is None and torch.cuda.is_available():
            clear_gpu_memory()
        
        _initializing = True
        
        try:
            log_info(f"Initializing Transformers with model: {MODEL_NAME}", print_output=True)
            
            # Get HF token
            hf_token, token_source = get_hf_token()
            
            if hf_token:
                log_info(f"{token_source} found (length: {len(hf_token)})", print_output=True)
                
                # Authenticate with Hugging Face Hub
                # login() automatically handles token precedence and environment variables
                try:
                    login(token=hf_token, add_to_git_credential=False)
                    log_info("Successfully authenticated with Hugging Face Hub", print_output=True)
                except Exception as e:
                    log_warning(f"Failed to authenticate with HF Hub: {e}", print_output=True)
            else:
                log_warning(
                    "No HF token found! Model download may fail if model is gated.",
                    print_output=True
                )
            
            # Load tokenizer
            # Modern transformers (4.45.0+) auto-load chat templates from model repo
            log_info("Loading tokenizer...", print_output=True)
            tokenizer = AutoTokenizer.from_pretrained(
                MODEL_NAME,
                token=hf_token,
                trust_remote_code=True,
                cache_dir=CACHE_DIR,
            )
            
            # Verify chat template is available (should be auto-loaded)
            if not hasattr(tokenizer, 'chat_template') or tokenizer.chat_template is None:
                log_warning("Chat template not found - will use fallback formatting")
            
            log_info("Tokenizer loaded", print_output=True)
            
            # Clear GPU memory before loading model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                import gc
                gc.collect()
            
            # Load model
            log_info("Loading model (this may take a few minutes)...", print_output=True)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                token=hf_token,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,  # torch_dtype is the kwarg name across transformers 4.x
                device_map="auto",
                max_memory={0: "20GiB"} if torch.cuda.is_available() else None,
                cache_dir=CACHE_DIR,
                low_cpu_mem_usage=True,
            )
            
            model.eval()
            _initialized = True
            
            log_info("Model loaded successfully!", print_output=True)
            
        except Exception as e:
            error_msg = f"Error initializing model: {e}"
            log_error(error_msg, exc_info=True, print_output=True)
            
            clear_gpu_memory()
            model = None
            tokenizer = None
            
            # Provide helpful error message for authentication issues
            if "401" in str(e) or "Unauthorized" in str(e) or "authentication" in str(e).lower():
                print("\nAuthentication Error Detected!")
                print("1. Ensure HF_TOKEN_LC2 is set in your environment")
                print("2. Accept model terms at: https://huggingface.co/DragonLLM/Qwen-Open-Finance-R-8B")
                print("3. Verify token has access to DragonLLM models")
            
            raise
        finally:
            _initializing = False


class TransformersProvider:
    """Provider for Transformers-based model inference."""
    
    def __init__(self):
        pass
    
    async def list_models(self) -> Dict[str, Any]:
        """List available models."""
        return {
            "object": "list",
            "data": [
                {
                    "id": MODEL_NAME,
                    "object": "model",
                    "created": 1677610602,
                    "owned_by": "DragonLLM",
                    "permission": [],
                    "root": MODEL_NAME,
                    "parent": None,
                }
            ]
        }
    
    async def chat(
        self, payload: Dict[str, Any], stream: bool = False
    ) -> Union[Dict[str, Any], AsyncIterator[str]]:
        """Handle chat completion requests."""
        try:
            # Initialize model on first use (thread-safe check)
            if not is_model_ready():
                log_info("Model not initialized, initializing now...")
                initialize_model()
                log_info("Model initialized successfully")
            
            messages = payload.get("messages", [])
            temperature = payload.get("temperature", DEFAULT_TEMPERATURE)
            max_tokens = payload.get("max_tokens", DEFAULT_MAX_TOKENS)
            top_p = payload.get("top_p", DEFAULT_TOP_P)
            tools = payload.get("tools", None)  # ✅ Extract tools
            tool_choice = payload.get("tool_choice", "auto")  # ✅ Extract tool_choice
            response_format = payload.get("response_format", None)  # ✅ Extract response_format
            
            # Handle tool_choice="required" - treat as "auto" for text-based tool calls
            if tool_choice == "required":
                tool_choice = "auto"
                log_info("tool_choice='required' converted to 'auto' for text-based tool calls")
            
            # Detect French and add system prompt if needed
            if is_french_request(messages) and not has_french_system_prompt(messages):
                messages = [{"role": "system", "content": FRENCH_SYSTEM_PROMPT}] + messages
            
            # ✅ Handle response_format for structured JSON outputs
            json_output_required = False
            if response_format:
                if isinstance(response_format, dict):
                    json_output_required = response_format.get("type") == "json_object"
                elif hasattr(response_format, "type"):
                    json_output_required = response_format.type == "json_object"
            
            # ✅ Add tools to system prompt if provided
            if tools:
                tools_description = self._format_tools_for_prompt(tools)
                # Add tools to the last system message or create a new one
                system_messages = [msg for msg in messages if msg.get("role") == "system"]
                if system_messages:
                    # Append to existing system message
                    last_system = system_messages[-1]
                    last_system["content"] = f"{last_system['content']}\n\n{tools_description}"
                else:
                    # Add new system message with tools
                    messages = [{"role": "system", "content": tools_description}] + messages
                log_info(f"Tools added to prompt: {len(tools)} tools")
            
            # ✅ Add JSON output requirement to system prompt if response_format requires it
            if json_output_required:
                json_instruction = (
                    "\n\nCRITICAL: response_format is set to json_object. You MUST respond with ONLY valid JSON. "
                    "NO <think> tags, NO reasoning, NO explanations, NO text before or after the JSON. "
                    "Start your response directly with { and end with }. "
                    "\n\nEXAMPLES:\n"
                    "If asked for a random number 1-10:\n"
                    "CORRECT: {\"nombre\": 7}\n"
                    "WRONG: <think>I need to generate...</think>{\"nombre\": 7}\n"
                    "WRONG: Here is the JSON: {\"nombre\": 7}\n"
                    "\nIf asked for portfolio data:\n"
                    "CORRECT: {\"positions\": [{\"symbole\": \"AIR.PA\", \"quantite\": 50}]}\n"
                    "WRONG: <think>Let me extract...</think>{\"positions\": [...]}\n"
                    "\nREMEMBER: Your response must be ONLY the JSON object, nothing else. Do not use <think> tags."
                )
                system_messages = [msg for msg in messages if msg.get("role") == "system"]
                if system_messages:
                    last_system = system_messages[-1]
                    last_system["content"] = f"{last_system['content']}{json_instruction}"
                else:
                    messages = [{"role": "system", "content": json_instruction}] + messages
                log_info("JSON output format enforced via system prompt")
            
            # Generate prompt using chat template
            if hasattr(tokenizer, "apply_chat_template"):
                prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
                log_info(f"Chat template applied. Messages: {len(messages)}")
                if any(msg.get("role") == "system" for msg in messages):
                    system_msg = next(msg for msg in messages if msg.get("role") == "system")
                    log_info(f"System message present: {system_msg['content'][:100]}...")
            else:
                prompt = self._messages_to_prompt(messages)
                log_warning("No chat_template found, using fallback")
            
            # Tokenize and move inputs to the model device (device_map="auto" handles
            # model placement, but inputs need explicit placement)
            inputs = tokenizer(prompt, return_tensors="pt")
            model_device = next(model.parameters()).device
            # Use BatchEncoding.to() rather than rebuilding a plain dict, so downstream
            # code can still access inputs.input_ids as an attribute
            inputs = inputs.to(model_device)
            
            # Handle streaming vs non-streaming
            if stream:
                return self._chat_stream(inputs, temperature, top_p, max_tokens, payload.get("model", MODEL_NAME), tools, json_output_required)
            
            return self._generate_response(inputs, temperature, top_p, max_tokens, payload.get("model", MODEL_NAME), tools, json_output_required)
            
        except Exception as e:
            log_error(f"Error in chat completion: {str(e)}", exc_info=True)
            raise
    
    def _generate_response(
        self, inputs, temperature: float, top_p: float, max_tokens: int, model_id: str, tools: Optional[List[Dict[str, Any]]] = None, json_output_required: bool = False
    ) -> Dict[str, Any]:
        """Generate non-streaming response."""
        # Prepare generation kwargs
        generation_kwargs = {
            "max_new_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": DEFAULT_TOP_K,
            "do_sample": temperature > 0,
            "pad_token_id": PAD_TOKEN_ID,
            "eos_token_id": EOS_TOKENS,
            "repetition_penalty": REPETITION_PENALTY,
            "early_stopping": False,
            "use_cache": True,
        }
        
        # Note: Qwen reasoning models are designed to use reasoning tags
        # We cannot completely disable reasoning, but we can:
        # 1. Use strong prompts (already done above)
        # 2. Post-process to extract desired output (done in _extract_json_from_text and _parse_tool_calls)
        # 3. Set temperature to 0 for completely deterministic JSON output
        #    Temperature=0 uses greedy decoding (always picks most likely token)
        #    This maximizes consistency for structured outputs
        if json_output_required:
            # Set temperature to 0 for completely deterministic JSON output
            # This uses greedy decoding which is ideal for structured formats
            original_temp = generation_kwargs["temperature"]
            generation_kwargs["temperature"] = 0.0
            generation_kwargs["do_sample"] = False  # Explicitly set for temperature=0
            log_info(f"Set temperature from {original_temp} to 0.0 (greedy decoding) for JSON output format")
        
        # Ensure inputs are on the model device before generation (no-op if already moved)
        model_device = next(model.parameters()).device
        inputs = inputs.to(model_device)
        
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                **generation_kwargs,
            )
        
        # Token accounting: prompt length from the input ids, completion length from the
        # newly generated ids decoded below
        prompt_tokens = len(inputs.input_ids[0])
        generated_ids = outputs[0][inputs.input_ids.shape[1]:]
        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
        completion_tokens = len(generated_ids)
        
        # ✅ If JSON output is required, try to extract JSON from the response
        if json_output_required:
            generated_text = self._extract_json_from_text(generated_text)
        
        # ✅ Parse tool calls from generated text
        tool_calls = None
        if tools:
            tool_calls = self._parse_tool_calls(generated_text, tools)
            if tool_calls:
                log_info(f"Parsed {len(tool_calls)} tool calls from response")
                # Remove tool call markers from content if present
                generated_text = self._clean_tool_calls_from_text(generated_text)
        
        finish_reason = "tool_calls" if tool_calls else ("length" if completion_tokens >= max_tokens else "stop")
        
        log_info(f"Generated {completion_tokens} tokens (max: {max_tokens}), finish: {finish_reason}")
        
        # Record statistics
        stats_tracker = get_stats_tracker()
        stats_tracker.record_request(RequestStats(
            timestamp=time.time(),
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            model=model_id,
            finish_reason=finish_reason,
        ))
        
        # Build message with optional tool_calls
        message = {"role": "assistant", "content": generated_text if generated_text.strip() else None}
        if tool_calls:
            message["tool_calls"] = tool_calls
        
        return {
            "id": f"chatcmpl-{os.urandom(12).hex()}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model_id,
            "choices": [
                {
                    "index": 0,
                    "message": message,
                    "finish_reason": finish_reason,
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }
    
    async def _chat_stream(
        self, inputs, temperature: float, top_p: float, max_tokens: int, model_id: str, tools: Optional[List[Dict[str, Any]]] = None, json_output_required: bool = False
    ) -> AsyncIterator[str]:
        """Stream chat completions."""
        completion_id = f"chatcmpl-{os.urandom(12).hex()}"
        created = int(time.time())
        
        # Count prompt tokens
        prompt_tokens = len(inputs.input_ids[0])
        completion_tokens = 0
        generated_text = ""
        
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        
        generation_kwargs = {
            "max_new_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": temperature > 0,
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            "min_new_tokens": min(10, max_tokens // 10),
            "repetition_penalty": REPETITION_PENALTY,
            "streamer": streamer,
        }
        
        def generate():
            # Ensure inputs are on the model device before generation (no-op if already moved)
            model_device = next(model.parameters()).device
            inputs_on_device = inputs.to(model_device)
            with torch.no_grad():
                model.generate(**inputs_on_device, **generation_kwargs)
        
        generation_thread = Thread(target=generate)
        generation_thread.start()
        
        try:
            for token in streamer:
                generated_text += token
                chunk = {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model_id,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {"content": token},
                            "finish_reason": None,
                        }
                    ],
                }
                yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
                await asyncio.sleep(0)
        finally:
            generation_thread.join()
            
            # Count completion tokens accurately from generated text
            if generated_text:
                # Use tokenizer to count tokens accurately
                completion_tokens = len(tokenizer.encode(generated_text, add_special_tokens=False))
            else:
                completion_tokens = 0
            
            # Record statistics for streaming request
            stats_tracker = get_stats_tracker()
            finish_reason = "length" if completion_tokens >= max_tokens else "stop"
            stats_tracker.record_request(RequestStats(
                timestamp=time.time(),
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
                model=model_id,
                finish_reason=finish_reason,
            ))
        
        # Send final chunk
        final_chunk = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model_id,
            "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
        }
        yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"
        yield "data: [DONE]\n\n"
    
    def _messages_to_prompt(self, messages: list) -> str:
        """Convert OpenAI messages format to prompt (fallback)."""
        prompt = ""
        for message in messages:
            role = message["role"]
            content = message["content"]
            if role == "system":
                prompt += f"System: {content}\n"
            elif role == "user":
                prompt += f"User: {content}\n"
            elif role == "assistant":
                prompt += f"Assistant: {content}\n"
        prompt += "Assistant: "
        return prompt
    
    def _remove_reasoning_tags(self, text: str) -> str:
        """Remove Qwen reasoning tags from text."""
        # Remove reasoning tags - matches <think>...</think>
        cleaned_text = re.sub(
            r'<think>.*?</think>',
            '',
            text,
            flags=re.DOTALL | re.IGNORECASE
        )
        
        # Handle a dangling closing tag (e.g. the opening tag was truncated away)
        if "</think>" in cleaned_text:
            parts = cleaned_text.split("</think>", 1)
            if len(parts) > 1:
                cleaned_text = parts[1].strip()
        
        # If still has opening tag but no closing, remove everything before first {
        if "<think>" in cleaned_text.lower() and "{" in cleaned_text:
            brace_pos = cleaned_text.find('{')
            if brace_pos != -1:
                cleaned_text = cleaned_text[brace_pos:]
        
        return cleaned_text
    
    def _extract_json_by_brace_matching(self, text: str, start_pos: int = 0) -> Optional[str]:
        """Extract a JSON object by matching braces starting at the given position."""
        brace_start = text.find('{', start_pos)
        if brace_start == -1:
            return None
        
        brace_count = 0
        in_string = False
        escape_next = False
        for i in range(brace_start, len(text)):
            if escape_next:
                escape_next = False
                continue
            if text[i] == '\\' and in_string:
                escape_next = True
            elif text[i] == '"':
                # Toggle string state so braces inside JSON strings are ignored
                in_string = not in_string
            elif text[i] == '{' and not in_string:
                brace_count += 1
            elif text[i] == '}' and not in_string:
                brace_count -= 1
                if brace_count == 0:
                    json_candidate = text[brace_start:i+1]
                    try:
                        json.loads(json_candidate)
                        return json_candidate
                    except json.JSONDecodeError:
                        return None
        return None
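    
    # Illustrative behavior of _extract_json_by_brace_matching (pure string
    # processing, safe to try in a REPL):
    #   _extract_json_by_brace_matching('noise {"a": 1} noise') -> '{"a": 1}'
    #   _extract_json_by_brace_matching('no json here')         -> None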
    
    def _format_tools_for_prompt(self, tools: List[Dict[str, Any]]) -> str:
        """Format tools for inclusion in system prompt."""
        tools_text = (
            "CRITICAL: You have access to the following tools. When you need to use a tool, "
            "you MUST respond ONLY with the tool call format below. NO reasoning tags, NO explanations, "
            "ONLY the tool call format.\n\n"
        )
        
        for i, tool in enumerate(tools, 1):
            func = tool.get("function", {})
            name = func.get("name", "")
            description = func.get("description", "")
            parameters = func.get("parameters", {})
            
            tools_text += f"Tool {i}: {name}\n"
            if description:
                tools_text += f"Description: {description}\n"
            if parameters:
                tools_text += f"Parameters: {json.dumps(parameters, ensure_ascii=False, indent=2)}\n"
            tools_text += "\n"
        
        tools_text += (
            "TO USE A TOOL, respond EXACTLY in this format (NO reasoning, NO text before or after):\n"
            "<tool_call>\n"
            '{"name": "tool_name", "arguments": {"param1": "value1", "param2": "value2"}}\n'
            "</tool_call>\n\n"
            "EXAMPLE 1 - If asked to calculate future value:\n"
            "<tool_call>\n"
            '{"name": "calculer_valeur_future", "arguments": {"capital_initial": 10000, "taux": 0.05, "duree": 10}}\n'
            "</tool_call>\n\n"
            "EXAMPLE 2 - If asked to get stock price:\n"
            "<tool_call>\n"
            '{"name": "obtenir_prix_action", "arguments": {"symbole": "AIR.PA"}}\n'
            "</tool_call>\n\n"
            "IMPORTANT: Start your response directly with <tool_call>. Do NOT include <think> tags or any reasoning. "
            "The tool call format is the ONLY thing you should output when using a tool."
        )
        
        return tools_text
    
    def _parse_tool_calls(self, generated_text: str, tools: List[Dict[str, Any]]) -> Optional[List[Dict[str, Any]]]:
        """Parse tool calls from generated text."""
        tool_calls = []
        
        # Remove reasoning tags to get clean text
        cleaned_text = self._remove_reasoning_tags(generated_text)
        
        # Pattern to match <tool_call>...</tool_call> blocks
        pattern = r'<tool_call>\s*({.*?})\s*</tool_call>'
        matches = re.findall(pattern, cleaned_text, re.DOTALL)
        
        # Also try to match JSON objects that look like tool calls
        if not matches:
            # Try to find JSON objects with "name" and "arguments" keys (more flexible pattern)
            # This handles cases where model generates JSON but not wrapped in tags
            json_pattern = r'\{\s*"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{[^}]+\}\s*\}'
            matches = re.findall(json_pattern, cleaned_text, re.DOTALL)
        
        # If still no matches, try to find any JSON object with "name" field that matches a tool name
        if not matches:
            tool_names = [t.get("function", {}).get("name", "") for t in tools]
            # Look for JSON objects that might be tool calls
            brace_start = 0
            while True:
                json_candidate = self._extract_json_by_brace_matching(cleaned_text, brace_start)
                if json_candidate is None:
                    break
                try:
                    candidate_data = json.loads(json_candidate)
                    if "name" in candidate_data and candidate_data["name"] in tool_names:
                        matches.append(json_candidate)
                        break
                except json.JSONDecodeError:
                    pass
                # Find next {
                brace_start = cleaned_text.find('{', cleaned_text.find(json_candidate) + len(json_candidate))
                if brace_start == -1:
                    break
        
        for i, match in enumerate(matches):
            try:
                call_data = json.loads(match)
                name = call_data.get("name", "")
                arguments = call_data.get("arguments", {})
                
                # Validate tool name exists in provided tools
                tool_names = [t.get("function", {}).get("name", "") for t in tools]
                if name not in tool_names:
                    log_warning(f"Tool call to unknown tool: {name}")
                    continue
                
                # Ensure arguments is a JSON string
                if isinstance(arguments, dict):
                    arguments_str = json.dumps(arguments, ensure_ascii=False)
                else:
                    arguments_str = str(arguments)
                
                tool_calls.append({
                    "id": f"call_{os.urandom(8).hex()}",
                    "type": "function",
                    "function": {
                        "name": name,
                        "arguments": arguments_str
                    }
                })
            except json.JSONDecodeError as e:
                log_warning(f"Failed to parse tool call JSON: {e}, match: {match[:100]}")
                continue
            except Exception as e:
                log_warning(f"Error parsing tool call: {e}")
                continue
        
        return tool_calls if tool_calls else None
    
    def _clean_tool_calls_from_text(self, text: str) -> str:
        """Remove tool call markers from text to return clean content."""
        # Remove <tool_call>...</tool_call> blocks
        text = re.sub(r'<tool_call>.*?</tool_call>', '', text, flags=re.DOTALL)
        # Clean up extra whitespace
        text = re.sub(r'\n\s*\n', '\n\n', text)
        return text.strip()
    
    def _extract_json_from_text(self, text: str) -> str:
        """Extract JSON from text, handling cases where JSON is wrapped in markdown, reasoning tags, or other text."""
        # Step 1: Remove reasoning tags first (Qwen reasoning models)
        cleaned_text = self._remove_reasoning_tags(text)
        
        # Step 2: Try to find JSON wrapped in markdown code blocks
        json_code_block = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', cleaned_text, re.DOTALL)
        if json_code_block:
            json_str = json_code_block.group(1).strip()
            try:
                json.loads(json_str)  # Validate
                return json_str
            except json.JSONDecodeError:
                pass
        
        # Step 3: Find JSON object(s) in the text
        # Use a more robust approach: find all { ... } patterns and validate them
        json_pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
        matches = re.finditer(json_pattern, cleaned_text, re.DOTALL)
        
        # Try to find the largest valid JSON object
        best_match = None
        best_length = 0
        
        for match in matches:
            json_candidate = match.group(0)
            try:
                json.loads(json_candidate)  # Validate
                if len(json_candidate) > best_length:
                    best_match = json_candidate
                    best_length = len(json_candidate)
            except json.JSONDecodeError:
                continue
        
        if best_match:
            return best_match.strip()
        
        # Step 4: Fallback - try to find any JSON-like structure using brace matching
        json_candidate = self._extract_json_by_brace_matching(cleaned_text)
        if json_candidate:
            return json_candidate.strip()
        
        # Step 5: If no JSON found, return cleaned text (without reasoning tags)
        # This allows the caller to handle it or show an error
        return cleaned_text.strip()


# Module-level provider instance
_provider = TransformersProvider()


def is_model_ready() -> bool:
    """
    Thread-safe check if the model is loaded and ready for inference.
    
    Returns:
        True if model is initialized and loaded, False otherwise.
    """
    with _init_lock:
        return _initialized and model is not None and tokenizer is not None


# Module-level functions for direct import
async def list_models() -> Dict[str, Any]:
    """List available models."""
    return await _provider.list_models()


async def chat(payload: Dict[str, Any], stream: bool = False) -> Union[Dict[str, Any], AsyncIterator[str]]:
    """Chat completion."""
    return await _provider.chat(payload, stream=stream)
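

# ---------------------------------------------------------------------------
# Minimal local sketch (not part of the public API): offline sanity checks for
# the pure text helpers, then an end-to-end call. The end-to-end call loads
# the real model, so it needs network access, an accepted model license, and
# enough VRAM.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    _p = TransformersProvider()
    # Offline checks (no model required)
    assert _p._remove_reasoning_tags('<think>plan</think>{"a": 1}') == '{"a": 1}'
    assert _p._extract_json_from_text('bruit {"n": 7} bruit') == '{"n": 7}'
    print("Text-helper checks passed")

    async def _demo():
        payload = {
            "model": MODEL_NAME,
            "messages": [{"role": "user", "content": "Say hello in one sentence."}],
            "max_tokens": 64,
        }
        response = await chat(payload, stream=False)
        print(json.dumps(response, indent=2, ensure_ascii=False))

    asyncio.run(_demo())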