update
- fastapi_app.py +16 -5
- llm/common.py +64 -0
- llm/deepinfra_api.py +153 -0
- llm/vllm_api.py +33 -65
- transaction_maps_search.py +9 -12
fastapi_app.py
CHANGED

@@ -6,7 +6,8 @@ import os
 import datetime
 import json
 import traceback
-from llm.
+from llm.common import LlmParams, LlmPredictParams
+from llm.deepinfra_api import DeepInfraApi
 
 # Set the path for log files
 LOGS_BASE_PATH = os.getenv("LOGS_BASE_PATH", "logs")

@@ -17,7 +18,9 @@ LOGS_BASE_PATH = os.getenv("LOGS_BASE_PATH", "logs")
 
 # Check if logs are enabled
 ENABLE_LOGS = os.getenv("ENABLE_LOGS", "0") == "1"
-
+LLM_API_URL = os.getenv("LLM_API_URL", "")
+LLM_API_KEY = os.getenv("LLM_API_KEY", "")
+LLM_USE_DEEPINFRA = os.getenv("LLM_USE_DEEPINFRA", "") == "1"
 
 class Query(BaseModel):
     query: str = ''

@@ -87,18 +90,26 @@ async def search_route(query: Query) -> dict:
 
     llm_params = getattr(query, "llm_params", None)
 
+    if llm_params is None:
+        llm_params = LlmParams(url=LLM_API_URL, api_key=LLM_API_KEY, model="mistralai/Mixtral-8x7B-Instruct-v0.1", predict_params=LlmPredictParams(temperature=0.15, top_p=0.95, min_p=0.05, seed=42, repeat_penalty=1.2, presence_penalty=1.1, n_predict=6000))
+
+    if LLM_USE_DEEPINFRA:
+        llm_api = DeepInfraApi(llm_params)
+
     if find_transaction_maps_by_question or find_transaction_maps_by_operation:
-        transaction_maps_results, answer = transaction_maps_search.search_transaction_map(
+        transaction_maps_results, answer = await transaction_maps_search.search_transaction_map(
             query=question,
             find_transaction_maps_by_question=find_transaction_maps_by_question,
-            k_neighbours=top
+            k_neighbours=top,
+            llm_api=llm_api)
 
         response = {'transaction_maps_results': transaction_maps_results}
 
     else:
         modified_query, titles, concat_docs, \
             relevant_consultations, predicted_explanation, \
-            llm_responses = await search.search(question, use_qe, use_olympic, categories,
+            llm_responses = await search.search(question, use_qe, use_olympic, categories, llm_params)
 
         results = [{'title': str(item1), 'text_for_llm': str(item2)} for item1, item2 in
                    zip(titles, concat_docs)]
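A minimal sketch of how the new defaulting logic in the route can be exercised outside FastAPI, assuming the module layout introduced in this commit (llm.common, llm.deepinfra_api, llm.vllm_api) and that both clients implement predict as declared in LlmApiProtocol; the helper name build_llm_api is illustrative only and not part of the commit.

import os
import asyncio

from llm.common import LlmApi, LlmParams, LlmPredictParams
from llm.deepinfra_api import DeepInfraApi
from llm.vllm_api import LlmApi as VllmApi  # the vLLM client keeps the LlmApi class name

def build_llm_api() -> LlmApi:
    # Mirror the route's fallback: env-driven connection settings plus fixed sampling params.
    params = LlmParams(
        url=os.getenv("LLM_API_URL", ""),
        api_key=os.getenv("LLM_API_KEY", ""),
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
        predict_params=LlmPredictParams(temperature=0.15, top_p=0.95, min_p=0.05, seed=42,
                                        repeat_penalty=1.2, presence_penalty=1.1, n_predict=6000),
    )
    # Same switch as in the route: DeepInfra when LLM_USE_DEEPINFRA=1, otherwise the vLLM client.
    if os.getenv("LLM_USE_DEEPINFRA", "") == "1":
        return DeepInfraApi(params)
    return VllmApi(params)

async def main():
    llm_api = build_llm_api()
    print(await llm_api.predict("Give one sentence about double-entry bookkeeping."))

if __name__ == "__main__":
    asyncio.run(main())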
llm/common.py
ADDED

@@ -0,0 +1,64 @@
from pydantic import BaseModel, Field
from typing import Optional, List, Protocol
from abc import ABC, abstractmethod

class LlmPredictParams(BaseModel):
    """
    Prediction parameters for the LLM.
    """
    system_prompt: Optional[str] = Field(None, description="System prompt.")
    user_prompt: Optional[str] = Field(None, description="Prompt template passed with the user role.")
    n_predict: Optional[int] = None
    temperature: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    min_p: Optional[float] = None
    seed: Optional[int] = None
    repeat_penalty: Optional[float] = None
    repeat_last_n: Optional[int] = None
    retry_if_text_not_present: Optional[str] = None
    retry_count: Optional[int] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    n_keep: Optional[int] = None
    cache_prompt: Optional[bool] = None
    stop: Optional[List[str]] = None


class LlmParams(BaseModel):
    """
    Core parameters for an LLM backend.
    """
    url: str
    model: Optional[str] = Field(None, description="For a local API this parameter is expected to be omitted: the first model from the list is used, because only one model is served. For DeepInfra this approach does not work and the model must be set explicitly.")
    type: Optional[str] = None
    default: Optional[bool] = None
    template: Optional[str] = None
    predict_params: Optional[LlmPredictParams] = None
    api_key: Optional[str] = None

class LlmApiProtocol(Protocol):
    async def tokenize(self, prompt: str) -> Optional[dict]:
        ...
    async def detokenize(self, tokens: List[int]) -> Optional[str]:
        ...
    async def trim_sources(self, sources: str, user_request: str, system_prompt: str = None) -> dict:
        ...
    async def predict(self, prompt: str) -> str:
        ...

class LlmApi:
    """
    Base class for LLM API clients.
    """
    params: LlmParams = None

    def create_headers(self) -> dict[str, str]:
        headers = {"Content-Type": "application/json"}

        if self.params.api_key is not None:
            headers["Authorization"] = self.params.api_key

        return headers
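A small sketch (names are assumed, not part of the commit) showing how LlmApi and LlmApiProtocol are meant to be combined: a concrete client stores its LlmParams in params, inherits create_headers, and satisfies the protocol structurally by implementing the four async methods. A stub like this can stand in for a real backend when testing callers such as TransactionMapsSearch.

from typing import Optional, List

from llm.common import LlmApi, LlmApiProtocol, LlmParams

class FakeLlmApi(LlmApi):
    """In-memory stand-in for a real LLM backend (hypothetical, for tests only)."""

    def __init__(self, params: LlmParams, canned_answer: str = "stub answer"):
        self.params = params
        self.canned_answer = canned_answer

    async def tokenize(self, prompt: str) -> Optional[dict]:
        # Fake token ids; a real client would call its backend's tokenize endpoint.
        return {"tokens": list(range(len(prompt.split())))}

    async def detokenize(self, tokens: List[int]) -> Optional[str]:
        return " ".join(str(t) for t in tokens)

    async def trim_sources(self, sources: str, user_request: str, system_prompt: str = None) -> dict:
        return {"sources": sources}

    async def predict(self, prompt: str) -> str:
        return self.canned_answer

fake = FakeLlmApi(LlmParams(url="http://localhost:8000", api_key="secret"))
# create_headers comes from the LlmApi base class and passes api_key through unchanged.
assert fake.create_headers() == {"Content-Type": "application/json", "Authorization": "secret"}
_: LlmApiProtocol = fake  # the stub also satisfies LlmApiProtocol structurally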
llm/deepinfra_api.py
ADDED

@@ -0,0 +1,153 @@
import json
from typing import Optional, List
import httpx
from llm.common import LlmPredictParams, LlmParams, LlmApi

class DeepInfraApi(LlmApi):
    """
    Class for working with the DeepInfra API.
    """

    def __init__(self, params: LlmParams):
        self.params = params


    async def get_models(self) -> List[str]:
        """
        Performs a GET request to the API to fetch the list of available models.

        Returns:
            list[str]: List of model identifiers.
            If an error occurs or no data is available, an empty list is returned.

        Raises:
            All HTTP request errors are logged to the console but are not re-raised.
        """
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(f"{self.params.url}/v1/openai/models", headers=self.create_headers())
                if response.status_code == 200:
                    json_data = response.json()
                    return [item['id'] for item in json_data.get('data', [])]
        except httpx.RequestError as error:
            print('Error fetching models:', error)
        return []

    def create_messages(self, prompt: str) -> List[dict]:
        """
        Builds the messages for the LLM from the given prompt and the system prompt (if set).

        Args:
            prompt (str): The user prompt.

        Returns:
            list[dict]: List of messages with roles and content.
        """
        actual_prompt = self.apply_llm_template_to_prompt(prompt)
        messages = []
        if self.params.predict_params and self.params.predict_params.system_prompt:
            messages.append({"role": "system", "content": self.params.predict_params.system_prompt})
        messages.append({"role": "user", "content": actual_prompt})
        return messages

    def apply_llm_template_to_prompt(self, prompt: str) -> str:
        """
        Applies the LLM template to the given prompt, if a template is set.

        Args:
            prompt (str): The user prompt.

        Returns:
            str: The prompt with the template applied (or the original prompt if no template is set).
        """
        actual_prompt = prompt
        if self.params.template is not None:
            actual_prompt = self.params.template.replace("{{PROMPT}}", actual_prompt)
        return actual_prompt

    async def tokenize(self, prompt: str) -> Optional[dict]:
        raise NotImplementedError("This function is not supported.")

    async def detokenize(self, tokens: List[int]) -> Optional[str]:
        raise NotImplementedError("This function is not supported.")

    async def create_request(self, prompt: str) -> dict:
        """
        Builds a prediction request from the LLM parameters.

        Args:
            prompt (str): The prompt for the request.

        Returns:
            dict: Dictionary with the request parameters.
        """

        request = {
            "stream": False,
            "model": self.params.model,
        }

        predict_params = self.params.predict_params
        if predict_params:
            if predict_params.stop:
                non_empty_stop = list(filter(lambda o: o != "", predict_params.stop))
                if non_empty_stop:
                    request["stop"] = non_empty_stop

            if predict_params.n_predict is not None:
                request["max_tokens"] = int(predict_params.n_predict or 0)

            request["temperature"] = float(predict_params.temperature or 0)
            if predict_params.top_k is not None:
                request["top_k"] = int(predict_params.top_k)

            if predict_params.top_p is not None:
                request["top_p"] = float(predict_params.top_p)

            if predict_params.min_p is not None:
                request["min_p"] = float(predict_params.min_p)

            if predict_params.seed is not None:
                request["seed"] = int(predict_params.seed)

            if predict_params.n_keep is not None:
                request["n_keep"] = int(predict_params.n_keep)

            if predict_params.cache_prompt is not None:
                request["cache_prompt"] = bool(predict_params.cache_prompt)

            if predict_params.repeat_penalty is not None:
                request["repetition_penalty"] = float(predict_params.repeat_penalty)

            if predict_params.repeat_last_n is not None:
                request["repeat_last_n"] = int(predict_params.repeat_last_n)

            if predict_params.presence_penalty is not None:
                request["presence_penalty"] = float(predict_params.presence_penalty)

            if predict_params.frequency_penalty is not None:
                request["frequency_penalty"] = float(predict_params.frequency_penalty)

        request["messages"] = self.create_messages(prompt)
        return request

    async def trim_sources(self, sources: str, user_request: str, system_prompt: str = None) -> dict:
        raise NotImplementedError("This function is not supported.")

    async def predict(self, prompt: str) -> str:
        """
        Sends the request to the API and returns the result.

        Args:
            prompt (str): Input text for the prediction.

        Returns:
            str: The generated text.
        """
        request = await self.create_request(prompt)

        async with httpx.AsyncClient() as client:
            response = await client.post(f"{self.params.url}/v1/openai/chat/completions", headers=self.create_headers(), json=request)
            if response.status_code == 200:
                return response.json()["choices"][0]["message"]["content"]
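A hedged usage sketch for the DeepInfra client. The base URL, token format, and model name below are assumptions (the commit only shows that the client appends /v1/openai/... to params.url and sends api_key verbatim in the Authorization header); the point is how predict_params are mapped onto the OpenAI-style request by create_request.

import asyncio

from llm.common import LlmParams, LlmPredictParams
from llm.deepinfra_api import DeepInfraApi

params = LlmParams(
    url="https://api.deepinfra.com",     # assumed base URL; the client appends /v1/openai/...
    api_key="Bearer <DEEPINFRA_TOKEN>",  # passed through unchanged by create_headers()
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    predict_params=LlmPredictParams(temperature=0.15, top_p=0.95, n_predict=512,
                                    system_prompt="Answer briefly."),
)

async def demo():
    api = DeepInfraApi(params)
    request = await api.create_request("What is a chart of accounts?")
    # n_predict is exposed as max_tokens, repeat_penalty as repetition_penalty, etc.
    print(request["model"], request["max_tokens"], request["messages"][0]["role"])
    print(await api.predict("What is a chart of accounts?"))

asyncio.run(demo())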
llm/vllm_api.py
CHANGED

@@ -3,51 +3,17 @@ from typing import Optional, List, Any
 
 import httpx
 from pydantic import BaseModel, Field
+from llm.common import LlmPredictParams, LlmParams, LlmApi
 
 
-class LlmPredictParams(BaseModel):
-    """
-    Prediction parameters for the LLM.
-    """
-    system_prompt: Optional[str] = Field(None, description="System prompt.")
-    user_prompt: Optional[str] = Field(None, description="Prompt template passed with the user role.")
-    n_predict: Optional[int] = None
-    temperature: Optional[float] = None
-    top_k: Optional[int] = None
-    top_p: Optional[float] = None
-    min_p: Optional[float] = None
-    seed: Optional[int] = None
-    repeat_penalty: Optional[float] = None
-    repeat_last_n: Optional[int] = None
-    retry_if_text_not_present: Optional[str] = None
-    retry_count: Optional[int] = None
-    presence_penalty: Optional[float] = None
-    frequency_penalty: Optional[float] = None
-    n_keep: Optional[int] = None
-    cache_prompt: Optional[bool] = None
-    stop: Optional[List[str]] = None
-
-
-class LlmParams(BaseModel):
-    """
-    Core parameters for an LLM backend.
-    """
-    url: str
-    type: Optional[str] = None
-    default: Optional[bool] = None
-    template: Optional[str] = None
-    predict_params: Optional[LlmPredictParams] = None
-
-
-class LlmApi:
+class LlmApi(LlmApi):
     """
     Class for working with the vLLM API.
     """
-    params: LlmParams = None
 
     def __init__(self, params: LlmParams):
-
-
+        self.params = params
+
     async def get_models(self) -> List[str]:
         """
         Performs a GET request to the API to fetch the list of available models.

@@ -61,13 +27,26 @@ class LlmApi:
         """
         try:
            async with httpx.AsyncClient() as client:
-                response = await client.get(f"{
+                response = await client.get(f"{self.params.url}/v1/models", headers=self.create_headers())
                 if response.status_code == 200:
                     json_data = response.json()
                     return [item['id'] for item in json_data.get('data', [])]
         except httpx.RequestError as error:
             print('Error fetching models:', error)
         return []
+
+    async def get_model(self) -> str:
+        model = None
+        if self.params.model is not None:
+            model = self.params.model
+        else:
+            models = await self.get_models()
+            model = models[0] if models else None
+
+        if model is None:
+            raise Exception("No model name provided and no models available.")
+
+        return model
 
     def create_messages(self, prompt: str) -> List[dict]:
         """

@@ -81,8 +60,8 @@ class LlmApi:
         """
         actual_prompt = self.apply_llm_template_to_prompt(prompt)
         messages = []
-        if 
-            messages.append({"role": "system", "content": 
+        if self.params.predict_params and self.params.predict_params.system_prompt:
+            messages.append({"role": "system", "content": self.params.predict_params.system_prompt})
         messages.append({"role": "user", "content": actual_prompt})
         return messages
 

@@ -97,8 +76,8 @@ class LlmApi:
             str: The prompt with the template applied (or the original prompt if no template is set).
         """
         actual_prompt = prompt
-        if 
-            actual_prompt = 
+        if self.params.template is not None:
+            actual_prompt = self.params.template.replace("{{PROMPT}}", actual_prompt)
         return actual_prompt
 
     async def tokenize(self, prompt: str) -> Optional[dict]:

@@ -112,14 +91,10 @@ class LlmApi:
             Optional[dict]: A dictionary with the tokens and the model's maximum length if the request succeeds.
             If the request fails, returns None.
         """
-        model = (await self.get_models())[0] if await self.get_models() else None
-        if not model:
-            print("No models available for tokenization.")
-            return None
 
         actual_prompt = self.apply_llm_template_to_prompt(prompt)
         request_data = {
-            "model": 
+            "model": await self.get_model(),
             "prompt": actual_prompt,
             "add_special_tokens": False,
         }

@@ -127,9 +102,9 @@ class LlmApi:
         try:
             async with httpx.AsyncClient() as client:
                 response = await client.post(
-                    f"{
+                    f"{self.params.url}/tokenize",
                     json=request_data,
-                    headers=
+                    headers=self.create_headers(),
                 )
                 if response.status_code == 200:
                     data = response.json()

@@ -155,19 +130,15 @@ class LlmApi:
             Optional[str]: The string produced by detokenization if the request succeeds.
             If the request fails, returns None.
         """
-
-
-            print("No models available for detokenization.")
-            return None
-
-        request_data = {"model": model, "tokens": tokens or []}
+
+        request_data = {"model": await self.get_model(), "tokens": tokens or []}
 
         try:
             async with httpx.AsyncClient() as client:
                 response = await client.post(
-                    f"{
+                    f"{self.params.url}/detokenize",
                     json=request_data,
-                    headers=
+                    headers=self.create_headers(),
                 )
                 if response.status_code == 200:
                     data = response.json()

@@ -192,17 +163,14 @@ class LlmApi:
         Returns:
             dict: Dictionary with the request parameters.
         """
-
-        if not models:
-            raise ValueError("No models available to create a request.")
-        model = models[0]
+        model = await self.get_model()
 
         request = {
             "stream": True,
             "model": model,
         }
 
-        predict_params = 
+        predict_params = self.params.predict_params
         if predict_params:
             if predict_params.stop:
                 non_empty_stop = list(filter(lambda o: o != "", predict_params.stop))

@@ -283,7 +251,7 @@ class LlmApi:
         # Maximum allowed number of tokens for the sources
         max_length = (
             max_token_count
-            - (
+            - (self.params.predict_params.n_predict or 0)
             - aux_token_count
             - system_prompt_token_count
         )

@@ -322,7 +290,7 @@ class LlmApi:
             request = await self.create_request(prompt)
 
             # Start the streaming request
-            async with client.stream("POST", f"{
+            async with client.stream("POST", f"{self.params.url}/v1/chat/completions", json=request) as response:
                 if response.status_code != 200:
                     # If there is an error, read the response to get the details
                     error_content = await response.aread()
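A minimal sketch of the vLLM client's model fallback, assuming a local vLLM server address; with params.model left unset, get_model resolves the name via GET /v1/models and takes the first entry, which is the behaviour described in the LlmParams.model field.

import asyncio

from llm.common import LlmParams, LlmPredictParams
from llm.vllm_api import LlmApi as VllmApi  # the module keeps the LlmApi class name

async def demo():
    api = VllmApi(LlmParams(url="http://localhost:8000",  # assumed local vLLM endpoint
                            predict_params=LlmPredictParams(system_prompt="Answer briefly.",
                                                            n_predict=256)))
    model = await api.get_model()        # first model reported by the server
    tokens = await api.tokenize("posting to account 51")
    print(model, tokens)

asyncio.run(demo())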
transaction_maps_search.py
CHANGED

@@ -3,14 +3,13 @@ from business_transaction_map.common.constants import DEVICE, DO_NORMALIZATION,
 from business_transaction_map.components.faiss_vector_database import FaissVectorDatabase
 from business_transaction_map.components.embedding_extraction import EmbeddingExtractor
 import os
-import requests
 from prompts import BUSINESS_TRANSACTION_PROMPT
+from llm.common import LlmApi
 
 
 db_files_path = os.environ.get("GLOBAL_TRANSACTION_MAPS_DATA_PATH", "transaction_maps_search_data/csv/карта_проводок_new.pkl")
 
 model_path = os.environ.get("GLOBAL_TRANSACTION_MAPS_MODEL_PATH", "")
-llm_api_endpoint = os.environ.get("LLM_API_ENDPOINT", "")
 
 class TransactionMapsSearch:
 

@@ -26,14 +25,11 @@ class TransactionMapsSearch:
         self.database = FaissVectorDatabase(str(db_files_path))
 
     @staticmethod
-    def extract_business_transaction_with_llm(question: str) -> str:
+    async def extract_business_transaction_with_llm(question: str, llm_api: LlmApi) -> str:
+        prompt = BUSINESS_TRANSACTION_PROMPT.replace('{{ЗАПРОС}}', question)
+        res = await llm_api.predict(prompt)
 
-
-
-        response = requests.post(url=llm_api_endpoint,
-                                 json={"prompt": f"[INST] {question} [/INST]",  # the spaces inside [INST] turned out to matter; without them the LLM can drift into endless nonsense generation
-                                       "temperature": 0.0})
-        return response.json()['content']
+        return res
 
 
     @staticmethod

@@ -66,13 +62,14 @@ class TransactionMapsSearch:
         return answer
 
 
-    def search_transaction_map(self,
+    async def search_transaction_map(self,
                                query: str = None,
                                find_transaction_maps_by_question: bool = False,
-                               k_neighbours: int = 15
+                               k_neighbours: int = 15,
+                               llm_api: LlmApi = None):
 
         if find_transaction_maps_by_question:
-            query = self.extract_business_transaction_with_llm(query)
+            query = await self.extract_business_transaction_with_llm(query, llm_api)
             cleaned_text = query.replace("\n", " ")
             # cleaned_text = 'query: ' + cleaned_text  # only for e5
             query_tokens = self.model.query_tokenization(cleaned_text)
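A hedged wiring sketch for the now-async search: the TransactionMapsSearch constructor arguments are not shown in this diff, so the default construction below is an assumption; the point is that any LlmApi implementation (DeepInfraApi here) can be injected and is used to rewrite the question before the FAISS lookup.

import asyncio

from llm.common import LlmParams, LlmPredictParams
from llm.deepinfra_api import DeepInfraApi
from transaction_maps_search import TransactionMapsSearch

async def demo():
    llm_api = DeepInfraApi(LlmParams(url="https://api.deepinfra.com",        # assumed base URL
                                     api_key="Bearer <DEEPINFRA_TOKEN>",     # assumed token format
                                     model="mistralai/Mixtral-8x7B-Instruct-v0.1",
                                     predict_params=LlmPredictParams(temperature=0.0)))
    searcher = TransactionMapsSearch()  # assumed default construction
    maps, answer = await searcher.search_transaction_map(
        query="How do I record receipt of a customer payment?",
        find_transaction_maps_by_question=True,
        k_neighbours=15,
        llm_api=llm_api)
    print(answer)

asyncio.run(demo())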