"""Transformers-based provider exposing OpenAI-compatible chat completions for
the DragonLLM Qwen3 model, with streaming, text-based tool-call parsing, and
JSON output mode."""

import os
import time
import json
import logging
import re
import asyncio
from threading import Thread, Lock
from typing import Dict, Any, AsyncIterator, Union, List, Optional

import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from app.utils.constants import (
    MODEL_NAME,
    CACHE_DIR,
    FRENCH_SYSTEM_PROMPT,
    EOS_TOKENS,
    PAD_TOKEN_ID,
    DEFAULT_MAX_TOKENS,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_P,
    DEFAULT_TOP_K,
    REPETITION_PENALTY,
    MODEL_INIT_TIMEOUT_SECONDS,
    MODEL_INIT_WAIT_INTERVAL_SECONDS,
)
from app.utils.helpers import (
    get_hf_token,
    is_french_request,
    has_french_system_prompt,
    log_info,
    log_warning,
    log_error,
)
from app.utils.memory import clear_gpu_memory
from app.utils.stats import get_stats_tracker, RequestStats

logger = logging.getLogger(__name__)


model = None
tokenizer = None
_init_lock = Lock()
_initializing = False
_initialized = False
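
# initialize_model() below uses double-checked locking: callers test
# _initialized before and after acquiring _init_lock so that only one thread
# ever loads the weights.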


def initialize_model(force_reload: bool = False):
    """
    Initialize the Transformers Qwen3 model.

    Args:
        force_reload: If True, reload the model even if it is already initialized.

    Initialization is thread-safe, with GPU memory cleanup on failure.
    Authenticates with the Hugging Face Hub so gated DragonLLM models can be
    downloaded.
    """
    global model, tokenizer, _initializing, _initialized

    # Fast path: already loaded and no reload requested.
    if not force_reload and _initialized and model is not None:
        return

    with _init_lock:
        # Re-check under the lock in case another thread finished loading.
        if not force_reload and _initialized and model is not None:
            return

        if _initializing:
            log_warning("Model initialization already in progress, waiting...")
            waited = 0.0
            while _initializing and waited < MODEL_INIT_TIMEOUT_SECONDS:
                time.sleep(MODEL_INIT_WAIT_INTERVAL_SECONDS)
                waited += MODEL_INIT_WAIT_INTERVAL_SECONDS
                if _initialized and model is not None:
                    return
            if waited >= MODEL_INIT_TIMEOUT_SECONDS:
                log_error("Model initialization timeout!", print_output=True)
                raise RuntimeError("Model initialization timed out")
            return

        if force_reload and model is not None:
            log_info("Force reload requested, clearing existing model...", print_output=True)
            clear_gpu_memory()
            model = None
            tokenizer = None
            _initialized = False

        if model is None and torch.cuda.is_available():
            clear_gpu_memory()

        _initializing = True

        try:
            log_info(f"Initializing Transformers with model: {MODEL_NAME}", print_output=True)

            hf_token, token_source = get_hf_token()

            if hf_token:
                log_info(f"{token_source} found (length: {len(hf_token)})", print_output=True)
                try:
                    login(token=hf_token, add_to_git_credential=False)
                    log_info("Successfully authenticated with Hugging Face Hub", print_output=True)
                except Exception as e:
                    log_warning(f"Failed to authenticate with HF Hub: {e}", print_output=True)
            else:
                log_warning(
                    "No HF token found! Model download may fail if the model is gated.",
                    print_output=True,
                )

            log_info("Loading tokenizer...", print_output=True)
            tokenizer = AutoTokenizer.from_pretrained(
                MODEL_NAME,
                token=hf_token,
                trust_remote_code=True,
                cache_dir=CACHE_DIR,
            )

            if not hasattr(tokenizer, "chat_template") or tokenizer.chat_template is None:
                log_warning("Chat template not found - will use fallback formatting")

            log_info("Tokenizer loaded", print_output=True)

            # Release any cached allocations before loading the weights.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                import gc
                gc.collect()

            log_info("Loading model (this may take a few minutes)...", print_output=True)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                token=hf_token,
                trust_remote_code=True,
                dtype=torch.bfloat16,
                device_map="auto",
                max_memory={0: "20GiB"} if torch.cuda.is_available() else None,
                cache_dir=CACHE_DIR,
                low_cpu_mem_usage=True,
            )

            model.eval()
            _initialized = True

            log_info("Model loaded successfully!", print_output=True)

        except Exception as e:
            error_msg = f"Error initializing model: {e}"
            log_error(error_msg, exc_info=True, print_output=True)

            clear_gpu_memory()
            model = None
            tokenizer = None

            if "401" in str(e) or "Unauthorized" in str(e) or "authentication" in str(e).lower():
                print("\nAuthentication Error Detected!")
                print("1. Ensure HF_TOKEN_LC2 is set in your environment")
                print("2. Accept model terms at: https://huggingface.co/DragonLLM/Qwen-Open-Finance-R-8B")
                print("3. Verify token has access to DragonLLM models")

            raise
        finally:
            _initializing = False
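
# Usage sketch: initialize_model() is normally invoked lazily from chat()
# below, but can also be called eagerly at startup:
#
#     initialize_model()                    # blocks until weights are loaded
#     initialize_model(force_reload=True)   # drop and reload the weights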


class TransformersProvider:
    """Provider for Transformers-based model inference."""

    def __init__(self):
        pass

    async def list_models(self) -> Dict[str, Any]:
        """List available models."""
        return {
            "object": "list",
            "data": [
                {
                    "id": MODEL_NAME,
                    "object": "model",
                    "created": 1677610602,
                    "owned_by": "DragonLLM",
                    "permission": [],
                    "root": MODEL_NAME,
                    "parent": None,
                }
            ],
        }

    async def chat(
        self, payload: Dict[str, Any], stream: bool = False
    ) -> Union[Dict[str, Any], AsyncIterator[str]]:
        """Handle chat completion requests."""
        try:
            # Lazy initialization on the first request.
            if not is_model_ready():
                log_info("Model not initialized, initializing now...")
                initialize_model()
                log_info("Model initialized successfully")

            messages = payload.get("messages", [])
            temperature = payload.get("temperature", DEFAULT_TEMPERATURE)
            max_tokens = payload.get("max_tokens", DEFAULT_MAX_TOKENS)
            top_p = payload.get("top_p", DEFAULT_TOP_P)
            tools = payload.get("tools", None)
            tool_choice = payload.get("tool_choice", "auto")
            response_format = payload.get("response_format", None)

            # Tool calls are parsed from generated text, so "required" cannot
            # be enforced; downgrade it to "auto".
            if tool_choice == "required":
                tool_choice = "auto"
                log_info("tool_choice='required' converted to 'auto' for text-based tool calls")

            # Prepend the French system prompt to French requests that lack one.
            if is_french_request(messages) and not has_french_system_prompt(messages):
                messages = [{"role": "system", "content": FRENCH_SYSTEM_PROMPT}] + messages

            json_output_required = False
            if response_format:
                if isinstance(response_format, dict):
                    json_output_required = response_format.get("type") == "json_object"
                elif hasattr(response_format, "type"):
                    json_output_required = response_format.type == "json_object"

            # Inject tool descriptions into the (last) system message.
            if tools:
                tools_description = self._format_tools_for_prompt(tools)
                system_messages = [msg for msg in messages if msg.get("role") == "system"]
                if system_messages:
                    last_system = system_messages[-1]
                    last_system["content"] = f"{last_system['content']}\n\n{tools_description}"
                else:
                    messages = [{"role": "system", "content": tools_description}] + messages
                log_info(f"Tools added to prompt: {len(tools)} tools")

            # Enforce JSON-only output via the system prompt.
            if json_output_required:
                json_instruction = (
                    "\n\nCRITICAL: response_format is set to json_object. You MUST respond with ONLY valid JSON. "
                    "NO <think> tags, NO reasoning, NO explanations, NO text before or after the JSON. "
                    "Start your response directly with { and end with }. "
                    "\n\nEXAMPLES:\n"
                    "If asked for a random number 1-10:\n"
                    "CORRECT: {\"nombre\": 7}\n"
                    "WRONG: <think>I need to generate...</think>{\"nombre\": 7}\n"
                    "WRONG: Here is the JSON: {\"nombre\": 7}\n"
                    "\nIf asked for portfolio data:\n"
                    "CORRECT: {\"positions\": [{\"symbole\": \"AIR.PA\", \"quantite\": 50}]}\n"
                    "WRONG: <think>Let me extract...</think>{\"positions\": [...]}\n"
                    "\nREMEMBER: Your response must be ONLY the JSON object, nothing else. Do not use <think> tags."
                )
                system_messages = [msg for msg in messages if msg.get("role") == "system"]
                if system_messages:
                    last_system = system_messages[-1]
                    last_system["content"] = f"{last_system['content']}{json_instruction}"
                else:
                    messages = [{"role": "system", "content": json_instruction}] + messages
                log_info("JSON output format enforced via system prompt")

            if hasattr(tokenizer, "apply_chat_template"):
                prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
                log_info(f"Chat template applied. Messages: {len(messages)}")
                if any(msg.get("role") == "system" for msg in messages):
                    system_msg = next(msg for msg in messages if msg.get("role") == "system")
                    log_info(f"System message present: {system_msg['content'][:100]}...")
            else:
                prompt = self._messages_to_prompt(messages)
                log_warning("No chat_template found, using fallback")

            inputs = tokenizer(prompt, return_tensors="pt")

            # Move inputs to the same device as the model.
            model_device = next(model.parameters()).device
            inputs = {k: v.to(model_device) for k, v in inputs.items()}

            if stream:
                return self._chat_stream(inputs, temperature, top_p, max_tokens, payload.get("model", MODEL_NAME), tools, json_output_required)

            return self._generate_response(inputs, temperature, top_p, max_tokens, payload.get("model", MODEL_NAME), tools, json_output_required)

        except Exception as e:
            log_error(f"Error in chat completion: {str(e)}", exc_info=True)
            raise
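
    # chat() consumes an OpenAI-style payload; a minimal illustrative example
    # (every field except "messages" falls back to the defaults above):
    #
    #     {
    #         "model": MODEL_NAME,
    #         "messages": [{"role": "user", "content": "Bonjour !"}],
    #         "temperature": 0.7,
    #         "max_tokens": 256,
    #         "response_format": {"type": "json_object"},  # optional
    #         "tools": [...],  # optional, OpenAI function-tool schema
    #     }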

    def _generate_response(
        self, inputs, temperature: float, top_p: float, max_tokens: int, model_id: str, tools: Optional[List[Dict[str, Any]]] = None, json_output_required: bool = False
    ) -> Dict[str, Any]:
        """Generate a non-streaming response."""
        generation_kwargs = {
            "max_new_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": DEFAULT_TOP_K,
            "do_sample": temperature > 0,
            "pad_token_id": PAD_TOKEN_ID,
            "eos_token_id": EOS_TOKENS,
            "repetition_penalty": REPETITION_PENALTY,
            "early_stopping": False,
            "use_cache": True,
        }

        # Greedy decoding produces more reliable JSON than sampling.
        if json_output_required:
            original_temp = generation_kwargs["temperature"]
            generation_kwargs["temperature"] = 0.0
            generation_kwargs["do_sample"] = False
            log_info(f"Set temperature from {original_temp} to 0.0 (greedy decoding) for JSON output format")

        model_device = next(model.parameters()).device
        inputs = {k: v.to(model_device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                **generation_kwargs,
            )

        # Strip the prompt tokens from the output before decoding.
        prompt_tokens = len(inputs["input_ids"][0])
        generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
        completion_tokens = len(generated_ids)

        if json_output_required:
            generated_text = self._extract_json_from_text(generated_text)

        tool_calls = None
        if tools:
            tool_calls = self._parse_tool_calls(generated_text, tools)
            if tool_calls:
                log_info(f"Parsed {len(tool_calls)} tool calls from response")
                generated_text = self._clean_tool_calls_from_text(generated_text)

        finish_reason = "tool_calls" if tool_calls else ("length" if completion_tokens >= max_tokens else "stop")

        log_info(f"Generated {completion_tokens} tokens (max: {max_tokens}), finish: {finish_reason}")

        stats_tracker = get_stats_tracker()
        stats_tracker.record_request(RequestStats(
            timestamp=time.time(),
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            model=model_id,
            finish_reason=finish_reason,
        ))

        message = {"role": "assistant", "content": generated_text if generated_text.strip() else None}
        if tool_calls:
            message["tool_calls"] = tool_calls

        return {
            "id": f"chatcmpl-{os.urandom(12).hex()}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model_id,
            "choices": [
                {
                    "index": 0,
                    "message": message,
                    "finish_reason": finish_reason,
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }

    async def _chat_stream(
        self, inputs, temperature: float, top_p: float, max_tokens: int, model_id: str, tools: Optional[List[Dict[str, Any]]] = None, json_output_required: bool = False
    ) -> AsyncIterator[str]:
        """Stream chat completions as server-sent events."""
        completion_id = f"chatcmpl-{os.urandom(12).hex()}"
        created = int(time.time())

        prompt_tokens = len(inputs["input_ids"][0])
        completion_tokens = 0
        generated_text = ""

        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = {
            "max_new_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": temperature > 0,
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            "min_new_tokens": min(10, max_tokens // 10),
            "repetition_penalty": REPETITION_PENALTY,
            "streamer": streamer,
        }

        # generate() blocks, so run it in a background thread and consume
        # tokens from the streamer as they arrive.
        def generate():
            model_device = next(model.parameters()).device
            inputs_on_device = {k: v.to(model_device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
            with torch.no_grad():
                model.generate(**inputs_on_device, **generation_kwargs)

        generation_thread = Thread(target=generate)
        generation_thread.start()

        try:
            for token in streamer:
                generated_text += token
                chunk = {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model_id,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {"content": token},
                            "finish_reason": None,
                        }
                    ],
                }
                yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
                # Yield control so other coroutines can run between chunks.
                await asyncio.sleep(0)
        finally:
            generation_thread.join()

        # The streamer yields decoded text, so re-encode to count tokens.
        if generated_text:
            completion_tokens = len(tokenizer.encode(generated_text, add_special_tokens=False))
        else:
            completion_tokens = 0

        stats_tracker = get_stats_tracker()
        finish_reason = "length" if completion_tokens >= max_tokens else "stop"
        stats_tracker.record_request(RequestStats(
            timestamp=time.time(),
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            model=model_id,
            finish_reason=finish_reason,
        ))

        final_chunk = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model_id,
            "choices": [{"index": 0, "delta": {}, "finish_reason": finish_reason}],
        }
        yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"
        yield "data: [DONE]\n\n"
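
    # Wire format of the stream above (OpenAI-compatible SSE):
    #
    #     data: {"id": "chatcmpl-...", "object": "chat.completion.chunk", ...}
    #
    #     data: [DONE]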

    def _messages_to_prompt(self, messages: list) -> str:
        """Convert OpenAI messages format to a plain prompt (fallback)."""
        prompt = ""
        for message in messages:
            role = message["role"]
            content = message["content"]
            if role == "system":
                prompt += f"System: {content}\n"
            elif role == "user":
                prompt += f"User: {content}\n"
            elif role == "assistant":
                prompt += f"Assistant: {content}\n"
        prompt += "Assistant: "
        return prompt
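
    # For [{"role": "user", "content": "Bonjour"}] this fallback yields
    # "User: Bonjour\nAssistant: "; the trailing "Assistant: " cue prompts
    # the model to continue in the assistant role.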

    def _remove_reasoning_tags(self, text: str) -> str:
        """Remove Qwen <think> reasoning tags from text."""
        # Drop fully closed <think>...</think> blocks.
        cleaned_text = re.sub(
            r'<think>.*?</think>',
            '',
            text,
            flags=re.DOTALL | re.IGNORECASE,
        )

        # If an unmatched </think> remains, keep only what follows it.
        if "</think>" in cleaned_text:
            parts = cleaned_text.split("</think>", 1)
            if len(parts) > 1:
                cleaned_text = parts[1].strip()

        # If an unclosed <think> precedes a JSON object, keep from the brace on.
        if "<think>" in cleaned_text.lower() and "{" in cleaned_text:
            brace_pos = cleaned_text.find('{')
            if brace_pos != -1:
                cleaned_text = cleaned_text[brace_pos:]

        return cleaned_text

    def _extract_json_by_brace_matching(self, text: str, start_pos: int = 0) -> Optional[str]:
        """Extract a JSON object by matching braces from the given position."""
        brace_start = text.find('{', start_pos)
        if brace_start == -1:
            return None

        brace_count = 0
        in_string = False
        escape_next = False
        for i in range(brace_start, len(text)):
            if escape_next:
                escape_next = False
                continue
            if text[i] == '\\':
                escape_next = True
            elif text[i] == '"':
                # Quotes toggle string context; braces inside strings are ignored.
                in_string = not in_string
            elif text[i] == '{' and not in_string:
                brace_count += 1
            elif text[i] == '}' and not in_string:
                brace_count -= 1
                if brace_count == 0:
                    json_candidate = text[brace_start:i + 1]
                    try:
                        json.loads(json_candidate)
                        return json_candidate
                    except json.JSONDecodeError:
                        return None
        return None
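
    # Example: _extract_json_by_brace_matching('noise {"a": {"b": 1}} tail')
    # returns '{"a": {"b": 1}}'; it returns None when the first balanced
    # candidate does not parse as JSON.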

    def _format_tools_for_prompt(self, tools: List[Dict[str, Any]]) -> str:
        """Format tool definitions for inclusion in the system prompt."""
        tools_text = (
            "CRITICAL: You have access to the following tools. When you need to use a tool, "
            "you MUST respond ONLY with the tool call format below. NO reasoning tags, NO explanations, "
            "ONLY the tool call format.\n\n"
        )

        for i, tool in enumerate(tools, 1):
            func = tool.get("function", {})
            name = func.get("name", "")
            description = func.get("description", "")
            parameters = func.get("parameters", {})

            tools_text += f"Tool {i}: {name}\n"
            if description:
                tools_text += f"Description: {description}\n"
            if parameters:
                tools_text += f"Parameters: {json.dumps(parameters, ensure_ascii=False, indent=2)}\n"
            tools_text += "\n"

        tools_text += (
            "TO USE A TOOL, respond EXACTLY in this format (NO reasoning, NO text before or after):\n"
            "<tool_call>\n"
            '{"name": "tool_name", "arguments": {"param1": "value1", "param2": "value2"}}\n'
            "</tool_call>\n\n"
            "EXAMPLE 1 - If asked to calculate future value:\n"
            "<tool_call>\n"
            '{"name": "calculer_valeur_future", "arguments": {"capital_initial": 10000, "taux": 0.05, "duree": 10}}\n'
            "</tool_call>\n\n"
            "EXAMPLE 2 - If asked to get stock price:\n"
            "<tool_call>\n"
            '{"name": "obtenir_prix_action", "arguments": {"symbole": "AIR.PA"}}\n'
            "</tool_call>\n\n"
            "IMPORTANT: Start your response directly with <tool_call>. Do NOT include <think> tags or any reasoning. "
            "The tool call format is the ONLY thing you should output when using a tool."
        )

        return tools_text

    def _parse_tool_calls(self, generated_text: str, tools: List[Dict[str, Any]]) -> Optional[List[Dict[str, Any]]]:
        """Parse tool calls from generated text."""
        tool_calls = []
        tool_names = [t.get("function", {}).get("name", "") for t in tools]

        # Strip reasoning tags so they cannot hide tool-call markup.
        cleaned_text = self._remove_reasoning_tags(generated_text)

        # Preferred format: <tool_call>{...}</tool_call> blocks.
        pattern = r'<tool_call>\s*({.*?})\s*</tool_call>'
        matches = re.findall(pattern, cleaned_text, re.DOTALL)

        # Fallback 1: bare {"name": ..., "arguments": {...}} objects.
        if not matches:
            json_pattern = r'\{\s*"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{[^}]+\}\s*\}'
            matches = re.findall(json_pattern, cleaned_text, re.DOTALL)

        # Fallback 2: brace-match any JSON object whose "name" is a known tool.
        if not matches:
            brace_start = 0
            while True:
                json_candidate = self._extract_json_by_brace_matching(cleaned_text, brace_start)
                if json_candidate is None:
                    break
                try:
                    candidate_data = json.loads(json_candidate)
                    if "name" in candidate_data and candidate_data["name"] in tool_names:
                        matches.append(json_candidate)
                        break
                except json.JSONDecodeError:
                    pass

                # Continue scanning after this candidate (search from
                # brace_start so an earlier duplicate cannot stall the loop).
                next_search = cleaned_text.find(json_candidate, brace_start) + len(json_candidate)
                brace_start = cleaned_text.find('{', next_search)
                if brace_start == -1:
                    break

        for match in matches:
            try:
                call_data = json.loads(match)
                name = call_data.get("name", "")
                arguments = call_data.get("arguments", {})

                if name not in tool_names:
                    log_warning(f"Tool call to unknown tool: {name}")
                    continue

                # OpenAI clients expect "arguments" as a JSON string.
                if isinstance(arguments, dict):
                    arguments_str = json.dumps(arguments, ensure_ascii=False)
                else:
                    arguments_str = str(arguments)

                tool_calls.append({
                    "id": f"call_{os.urandom(8).hex()}",
                    "type": "function",
                    "function": {
                        "name": name,
                        "arguments": arguments_str,
                    },
                })
            except json.JSONDecodeError as e:
                log_warning(f"Failed to parse tool call JSON: {e}, match: {match[:100]}")
                continue
            except Exception as e:
                log_warning(f"Error parsing tool call: {e}")
                continue

        return tool_calls if tool_calls else None
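
    # A successful parse yields OpenAI-style tool_call entries, e.g.:
    #
    #     {"id": "call_1a2b...", "type": "function",
    #      "function": {"name": "obtenir_prix_action",
    #                   "arguments": "{\"symbole\": \"AIR.PA\"}"}}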

    def _clean_tool_calls_from_text(self, text: str) -> str:
        """Remove tool call markers from text to return clean content."""
        # Drop <tool_call> blocks, then collapse leftover blank lines.
        text = re.sub(r'<tool_call>.*?</tool_call>', '', text, flags=re.DOTALL)
        text = re.sub(r'\n\s*\n', '\n\n', text)
        return text.strip()

    def _extract_json_from_text(self, text: str) -> str:
        """Extract JSON from text that may be wrapped in markdown, reasoning tags, or prose."""
        cleaned_text = self._remove_reasoning_tags(text)

        # Strategy 1: a ```json fenced code block.
        json_code_block = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', cleaned_text, re.DOTALL)
        if json_code_block:
            json_str = json_code_block.group(1).strip()
            try:
                json.loads(json_str)
                return json_str
            except json.JSONDecodeError:
                pass

        # Strategy 2: regex for objects with one level of nesting; keep the
        # longest candidate that parses.
        json_pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
        matches = re.finditer(json_pattern, cleaned_text, re.DOTALL)

        best_match = None
        best_length = 0

        for match in matches:
            json_candidate = match.group(0)
            try:
                json.loads(json_candidate)
                if len(json_candidate) > best_length:
                    best_match = json_candidate
                    best_length = len(json_candidate)
            except json.JSONDecodeError:
                continue

        if best_match:
            return best_match.strip()

        # Strategy 3: full brace matching for arbitrarily nested objects.
        json_candidate = self._extract_json_by_brace_matching(cleaned_text)
        if json_candidate:
            return json_candidate.strip()

        # Nothing parseable found; return the cleaned text unchanged.
        return cleaned_text.strip()
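
    # e.g. 'Voici le JSON : ```json\n{"nombre": 7}\n``` merci' returns
    # '{"nombre": 7}' via strategy 1.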


_provider = TransformersProvider()


def is_model_ready() -> bool:
    """
    Thread-safe check that the model is loaded and ready for inference.

    Returns:
        True if the model is initialized and loaded, False otherwise.
    """
    with _init_lock:
        return _initialized and model is not None and tokenizer is not None


async def list_models() -> Dict[str, Any]:
    """List available models."""
    return await _provider.list_models()


async def chat(payload: Dict[str, Any], stream: bool = False) -> Union[Dict[str, Any], AsyncIterator[str]]:
    """Chat completion."""
    return await _provider.chat(payload, stream=stream)
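
# Minimal usage sketch (illustrative; chat() and list_models() above are the
# real entry points, the event-loop wrapper is an assumption about the caller):
#
#     async def demo():
#         response = await chat({"messages": [{"role": "user", "content": "Bonjour"}]})
#         print(response["choices"][0]["message"]["content"])
#
#     asyncio.run(demo())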