from typing import Dict, List, Optional, Union, Any
import logging

from openai import OpenAI

from lpm_kernel.L1.bio import Bio, Chat, Note, Todo, UserInfo
from lpm_kernel.L1.prompt import PREFER_LANGUAGE_SYSTEM_PROMPT, STATUS_BIO_SYSTEM_PROMPT
from lpm_kernel.L1.utils import get_cur_time, is_valid_chat, is_valid_note, is_valid_todo
from lpm_kernel.api.services.user_llm_config_service import UserLLMConfigService
from lpm_kernel.configs.config import Config
from lpm_kernel.configs.logging import get_train_process_logger

logger = get_train_process_logger()


class StatusBioGenerator:
    def __init__(self):
        self.preferred_language = "English"
        self.model_params = {
            "temperature": 0,
            "max_tokens": 1000,
            "top_p": 0,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "seed": 42,
        }
        self.user_llm_config_service = UserLLMConfigService()
        self.user_llm_config = self.user_llm_config_service.get_available_llm()
        if self.user_llm_config is None:
            self.client = None
            self.model_name = None
        else:
            self.client = OpenAI(
                api_key=self.user_llm_config.chat_api_key,
                base_url=self.user_llm_config.chat_endpoint,
                timeout=45.0,  # Global request timeout in seconds
            )
            self.model_name = self.user_llm_config.chat_model_name
        self._top_p_adjusted = False  # Tracks whether top_p has been adjusted after an API error

    def _fix_top_p_param(self, error_message: str) -> bool:
        """Fixes the top_p parameter if an API error indicates it is invalid.

        Some LLM providers do not accept top_p=0 and require values in a specific
        range. This method checks whether the error relates to top_p and, if so,
        adjusts it to 0.001, which is close enough to 0 to keep behavior nearly
        deterministic while satisfying the API's requirements.

        Args:
            error_message: Error message from the API response.

        Returns:
            bool: True if top_p was adjusted, False otherwise.
        """
        if not self._top_p_adjusted and "top_p" in error_message.lower():
            logger.warning("Adjusting top_p from 0 to 0.001 to comply with model API requirements")
            self.model_params["top_p"] = 0.001
            self._top_p_adjusted = True
            return True
        return False
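
    # Illustrative trigger (hypothetical provider wording): a 400 response like
    #   "Invalid value for 'top_p': expected a value in (0, 1]"
    # contains "top_p" and would cause the adjustment above.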

    def _call_llm_with_retry(self, messages: List[Dict[str, str]], **kwargs) -> Any:
        """Calls the LLM API, retrying once after automatic parameter adjustment.

        Makes the chat-completion call and applies an automatic parameter fix
        when an error occurs. If the API rejects the call with a 400 error
        caused by an invalid top_p value, the parameter is adjusted and the
        call is retried exactly once.

        Args:
            messages: List of messages for the API call.
            **kwargs: Additional parameters to pass to the API call.

        Returns:
            API response object from the language model.

        Raises:
            RuntimeError: If no LLM client is configured.
            Exception: If the API call fails after the retry or for unrelated errors.
        """
        if self.client is None:
            # Guard against the case where __init__ found no available user LLM config
            raise RuntimeError("No LLM client configured; cannot call the chat API")
        try:
            return self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                **self.model_params,
                **kwargs,
            )
        except Exception as e:
            error_msg = str(e)
            logger.error(f"API Error: {error_msg}")
            # Try to fix the top_p parameter and retry once if needed
            if hasattr(e, "response") and hasattr(e.response, "status_code") and e.response.status_code == 400:
                if self._fix_top_p_param(error_msg):
                    logger.info("Retrying LLM API call with adjusted top_p parameter")
                    return self.client.chat.completions.create(
                        model=self.model_name,
                        messages=messages,
                        **self.model_params,
                        **kwargs,
                    )
            # Re-raise the exception
            raise
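
    # Note: with the openai v1 client, HTTP failures raise APIStatusError, whose
    # `response` attribute is an httpx.Response exposing `status_code`; the
    # hasattr checks above keep this compatible with other exception types.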

    def _build_message(self, user_info: UserInfo, language: str) -> List[Dict[str, str]]:
        """Builds the message list for generating a status biography.

        Args:
            user_info: User information object.
            language: Preferred language for the response.

        Returns:
            List of messages formatted for the LLM API.
        """
        messages = [
            {"role": "system", "content": STATUS_BIO_SYSTEM_PROMPT},
            {"role": "user", "content": str(user_info)},
        ]
        if language:
            messages.append(
                {
                    "role": "system",
                    "content": PREFER_LANGUAGE_SYSTEM_PROMPT.format(language=language),
                }
            )
        return messages
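
    # Example shape of the returned list (illustrative; the actual prompt text
    # comes from STATUS_BIO_SYSTEM_PROMPT and PREFER_LANGUAGE_SYSTEM_PROMPT):
    # [
    #     {"role": "system", "content": "<STATUS_BIO_SYSTEM_PROMPT>"},
    #     {"role": "user", "content": "<stringified UserInfo>"},
    #     {"role": "system", "content": "<PREFER_LANGUAGE_SYSTEM_PROMPT formatted with language>"},
    # ]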

    def generate_status_bio(self, notes: List[Note], todos: List[Todo],
                            chats: List[Chat]) -> Bio:
        """Generates a status biography from the user's notes, todos, and chats.

        Args:
            notes: List of the user's notes.
            todos: List of the user's todos.
            chats: List of the user's chats.

        Returns:
            Bio object containing the generated content.
        """
        cur_time = get_cur_time()
        user_info = UserInfo(cur_time, notes, todos, chats)
        messages = self._build_message(user_info, self.preferred_language)
        answer = self._call_llm_with_retry(messages)
        content = answer.choices[0].message.content
        logger.info(f"Generated content: {content}")
        # Create and return the Bio object, ensuring all content fields are populated
        return Bio(
            contentThirdView=content,  # Generated content, third-person view
            content=content,           # Generated content, second-person view
            summaryThirdView=content,  # Summary, third-person view
            summary=content,           # Summary, second-person view
            attributeList=[],
            shadesList=[],
        )
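

# --- Illustrative usage (a minimal sketch, not part of the original module) ---
# Assumes a valid user LLM config is available and that Bio fields mirror the
# keyword arguments used above; Note/Todo/Chat construction is not shown in
# this file, so empty lists are passed here purely for illustration.
if __name__ == "__main__":
    generator = StatusBioGenerator()
    if generator.client is None:
        logger.error("No available user LLM config; cannot generate a status bio.")
    else:
        bio = generator.generate_status_bio(notes=[], todos=[], chats=[])
        logger.info(f"Generated status bio summary: {bio.summary}")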