Spaces:
Runtime error
Runtime error
update
Browse files- common/auth.py +7 -3
- components/llm/deepinfra_api.py +43 -29
- components/services/dialogue.py +5 -2
- routes/auth.py +10 -6
- routes/llm.py +7 -1
common/auth.py
CHANGED
|
@@ -11,10 +11,14 @@ SECRET_KEY = os.environ.get("JWT_SECRET", "ooooooh_thats_my_super_secret_key")
|
|
| 11 |
ALGORITHM = "HS256"
|
| 12 |
ACCESS_TOKEN_EXPIRE_MINUTES = 1440
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
# Захардкоженные пользователи
|
| 15 |
USERS = [
|
| 16 |
-
|
| 17 |
-
|
| 18 |
]
|
| 19 |
|
| 20 |
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login/token")
|
|
@@ -39,7 +43,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
|
|
| 39 |
username: str = payload.get("sub")
|
| 40 |
if username is None:
|
| 41 |
raise HTTPException(status_code=401, detail="Invalid token")
|
| 42 |
-
user = next((u for u in USERS if u
|
| 43 |
if user is None:
|
| 44 |
raise HTTPException(status_code=401, detail="User not found")
|
| 45 |
return user
|
|
|
|
| 11 |
ALGORITHM = "HS256"
|
| 12 |
ACCESS_TOKEN_EXPIRE_MINUTES = 1440
|
| 13 |
|
| 14 |
+
class User(BaseModel):
|
| 15 |
+
username: str
|
| 16 |
+
password: str
|
| 17 |
+
|
| 18 |
# Захардкоженные пользователи
|
| 19 |
USERS = [
|
| 20 |
+
User(username="admin", password="admin123"),
|
| 21 |
+
User(username="demo", password="sTrUPsORPA")
|
| 22 |
]
|
| 23 |
|
| 24 |
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login/token")
|
|
|
|
| 43 |
username: str = payload.get("sub")
|
| 44 |
if username is None:
|
| 45 |
raise HTTPException(status_code=401, detail="Invalid token")
|
| 46 |
+
user = next((u for u in USERS if u.username == username), None)
|
| 47 |
if user is None:
|
| 48 |
raise HTTPException(status_code=401, detail="User not found")
|
| 49 |
return user
|
components/llm/deepinfra_api.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import json
|
| 2 |
from typing import AsyncGenerator, Optional, List
|
| 3 |
import httpx
|
|
@@ -256,13 +257,17 @@ class DeepInfraApi(LlmApi):
|
|
| 256 |
logging.error(f"Request failed: status code {response.status_code}")
|
| 257 |
logging.error(response.text)
|
| 258 |
|
| 259 |
-
async def predict_chat_stream(self, request: ChatRequest, system_prompt, params: LlmPredictParams) -> str:
|
| 260 |
"""
|
| 261 |
Выполняет запрос к API с поддержкой потокового вывода (SSE) и возвращает результат.
|
| 262 |
-
|
| 263 |
Args:
|
| 264 |
-
|
| 265 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
Returns:
|
| 267 |
str: Сгенерированный текст.
|
| 268 |
"""
|
|
@@ -270,32 +275,41 @@ class DeepInfraApi(LlmApi):
|
|
| 270 |
request = self.create_chat_request(request, system_prompt, params)
|
| 271 |
request["stream"] = True
|
| 272 |
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
|
| 298 |
-
return generated_text.strip()
|
| 299 |
|
| 300 |
async def get_predict_chat_generator(self, request: ChatRequest, system_prompt: str,
|
| 301 |
params: LlmPredictParams) -> AsyncGenerator[str, None]:
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
import json
|
| 3 |
from typing import AsyncGenerator, Optional, List
|
| 4 |
import httpx
|
|
|
|
| 257 |
logging.error(f"Request failed: status code {response.status_code}")
|
| 258 |
logging.error(response.text)
|
| 259 |
|
| 260 |
+
async def predict_chat_stream(self, request: ChatRequest, system_prompt, params: LlmPredictParams, max_retries: int = 3, retry_delay: float = 0.5) -> str:
|
| 261 |
"""
|
| 262 |
Выполняет запрос к API с поддержкой потокового вывода (SSE) и возвращает результат.
|
| 263 |
+
|
| 264 |
Args:
|
| 265 |
+
request (ChatRequest): Запрос чата
|
| 266 |
+
system_prompt: Системный промпт
|
| 267 |
+
params (LlmPredictParams): Параметры предсказания
|
| 268 |
+
max_retries (int): Максимальное количество попыток переподключения (по умолчанию 3)
|
| 269 |
+
retry_delay (float): Задержка между попытками в секундах (по умолчанию 0.5)
|
| 270 |
+
|
| 271 |
Returns:
|
| 272 |
str: Сгенерированный текст.
|
| 273 |
"""
|
|
|
|
| 275 |
request = self.create_chat_request(request, system_prompt, params)
|
| 276 |
request["stream"] = True
|
| 277 |
|
| 278 |
+
for attempt in range(max_retries + 1):
|
| 279 |
+
try:
|
| 280 |
+
async with client.stream("POST", f"{self.params.url}/v1/openai/chat/completions",
|
| 281 |
+
json=request,
|
| 282 |
+
headers=super().create_headers()) as response:
|
| 283 |
+
|
| 284 |
+
if response.status_code != 200:
|
| 285 |
+
error_content = await response.aread()
|
| 286 |
+
raise Exception(f"API error: {error_content.decode('utf-8')}")
|
| 287 |
+
|
| 288 |
+
generated_text = ""
|
| 289 |
+
|
| 290 |
+
async for line in response.aiter_lines():
|
| 291 |
+
if line.startswith("data: "):
|
| 292 |
+
try:
|
| 293 |
+
data = json.loads(line[len("data: "):].strip())
|
| 294 |
+
if data == "[DONE]":
|
| 295 |
+
break
|
| 296 |
+
if "choices" in data and data["choices"]:
|
| 297 |
+
token_value = data["choices"][0].get("delta", {}).get("content", "")
|
| 298 |
+
generated_text += token_value
|
| 299 |
+
except json.JSONDecodeError:
|
| 300 |
+
continue
|
| 301 |
+
|
| 302 |
+
return generated_text.strip()
|
| 303 |
+
|
| 304 |
+
except Exception as e:
|
| 305 |
+
if attempt < max_retries:
|
| 306 |
+
# Ждем перед следующей попыткой, если это не последняя попытка
|
| 307 |
+
await asyncio.sleep(retry_delay)
|
| 308 |
+
continue
|
| 309 |
+
else:
|
| 310 |
+
# Если исчерпаны все попытки, пробрасываем исключение
|
| 311 |
+
raise Exception(f"predict_chat_stream failed after {max_retries} retries: {str(e)}")
|
| 312 |
|
|
|
|
| 313 |
|
| 314 |
async def get_predict_chat_generator(self, request: ChatRequest, system_prompt: str,
|
| 315 |
params: LlmPredictParams) -> AsyncGenerator[str, None]:
|
components/services/dialogue.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
import logging
|
| 2 |
import os
|
| 3 |
import re
|
| 4 |
-
from typing import List
|
| 5 |
|
| 6 |
from pydantic import BaseModel
|
| 7 |
|
|
@@ -19,6 +19,7 @@ logger = logging.getLogger(__name__)
|
|
| 19 |
class QEResult(BaseModel):
|
| 20 |
use_search: bool
|
| 21 |
search_query: str | None
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
class DialogueService:
|
|
@@ -71,6 +72,7 @@ class DialogueService:
|
|
| 71 |
return QEResult(
|
| 72 |
use_search=from_chat is not None,
|
| 73 |
search_query=from_chat.content if from_chat else None,
|
|
|
|
| 74 |
)
|
| 75 |
|
| 76 |
def get_qe_result_from_chat(self, history: List[Message]) -> QEResult:
|
|
@@ -129,7 +131,8 @@ class DialogueService:
|
|
| 129 |
else:
|
| 130 |
raise ValueError("Первая часть текста должна содержать 'ДА' или 'НЕТ'.")
|
| 131 |
|
| 132 |
-
return QEResult(use_search=bool_var, search_query=second_part
|
|
|
|
| 133 |
|
| 134 |
def _get_search_query(self, history: List[Message]) -> Message | None:
|
| 135 |
"""
|
|
|
|
| 1 |
import logging
|
| 2 |
import os
|
| 3 |
import re
|
| 4 |
+
from typing import List, Optional, Tuple
|
| 5 |
|
| 6 |
from pydantic import BaseModel
|
| 7 |
|
|
|
|
| 19 |
class QEResult(BaseModel):
|
| 20 |
use_search: bool
|
| 21 |
search_query: str | None
|
| 22 |
+
debug_message: Optional[str | None] = ""
|
| 23 |
|
| 24 |
|
| 25 |
class DialogueService:
|
|
|
|
| 72 |
return QEResult(
|
| 73 |
use_search=from_chat is not None,
|
| 74 |
search_query=from_chat.content if from_chat else None,
|
| 75 |
+
debug_message=response
|
| 76 |
)
|
| 77 |
|
| 78 |
def get_qe_result_from_chat(self, history: List[Message]) -> QEResult:
|
|
|
|
| 131 |
else:
|
| 132 |
raise ValueError("Первая часть текста должна содержать 'ДА' или 'НЕТ'.")
|
| 133 |
|
| 134 |
+
return QEResult(use_search=bool_var, search_query=second_part,
|
| 135 |
+
debug_message=input_text)
|
| 136 |
|
| 137 |
def _get_search_query(self, history: List[Message]) -> Message | None:
|
| 138 |
"""
|
routes/auth.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
-
from typing import Optional
|
| 2 |
-
from fastapi import APIRouter, Body, Form, HTTPException
|
| 3 |
from datetime import timedelta
|
| 4 |
import common.auth as auth
|
| 5 |
|
|
@@ -7,8 +7,8 @@ router = APIRouter(prefix="/auth", tags=["Auth"])
|
|
| 7 |
|
| 8 |
def authenticate_user(username: str, password: str):
|
| 9 |
"""Проверяет, существует ли пользователь и правильный ли пароль."""
|
| 10 |
-
user = next((u for u in auth.USERS if u
|
| 11 |
-
if not user or user
|
| 12 |
raise HTTPException(status_code=401, detail="Неверный логин или пароль")
|
| 13 |
return user
|
| 14 |
|
|
@@ -20,7 +20,7 @@ def generate_access_token(username: str):
|
|
| 20 |
async def login_common(username: str, password: str):
|
| 21 |
"""Общий метод аутентификации."""
|
| 22 |
user = authenticate_user(username, password)
|
| 23 |
-
access_token = generate_access_token(user
|
| 24 |
return {"access_token": access_token, "token_type": "bearer"}
|
| 25 |
|
| 26 |
@router.post("/login", summary="Авторизация через JSON")
|
|
@@ -31,4 +31,8 @@ async def login_json(request: auth.LoginRequest = Body(...)):
|
|
| 31 |
@router.post("/login/token", summary="Авторизация через Form-Data")
|
| 32 |
async def login_form(username: str = Form(...), password: str = Form(...)):
|
| 33 |
"""Принимает Form-Data (x-www-form-urlencoded)."""
|
| 34 |
-
return await login_common(username, password)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Annotated, Optional
|
| 2 |
+
from fastapi import APIRouter, Body, Depends, Form, HTTPException
|
| 3 |
from datetime import timedelta
|
| 4 |
import common.auth as auth
|
| 5 |
|
|
|
|
| 7 |
|
| 8 |
def authenticate_user(username: str, password: str):
|
| 9 |
"""Проверяет, существует ли пользователь и правильный ли пароль."""
|
| 10 |
+
user = next((u for u in auth.USERS if u.username == username), None)
|
| 11 |
+
if not user or user.password != password:
|
| 12 |
raise HTTPException(status_code=401, detail="Неверный логин или пароль")
|
| 13 |
return user
|
| 14 |
|
|
|
|
| 20 |
async def login_common(username: str, password: str):
|
| 21 |
"""Общий метод аутентификации."""
|
| 22 |
user = authenticate_user(username, password)
|
| 23 |
+
access_token = generate_access_token(user.username)
|
| 24 |
return {"access_token": access_token, "token_type": "bearer"}
|
| 25 |
|
| 26 |
@router.post("/login", summary="Авторизация через JSON")
|
|
|
|
| 31 |
@router.post("/login/token", summary="Авторизация через Form-Data")
|
| 32 |
async def login_form(username: str = Form(...), password: str = Form(...)):
|
| 33 |
"""Принимает Form-Data (x-www-form-urlencoded)."""
|
| 34 |
+
return await login_common(username, password)
|
| 35 |
+
|
| 36 |
+
@router.post("/checktoken", summary="Проверяет, аутентифицирован ли пользователь")
|
| 37 |
+
async def check_token(current_user: Annotated[auth.User, Depends(auth.get_current_user)]):
|
| 38 |
+
return {"current_user": current_user.username}
|
routes/llm.py
CHANGED
|
@@ -123,7 +123,13 @@ async def sse_generator(request: ChatRequest, llm_api: DeepInfraApi, system_prom
|
|
| 123 |
"""
|
| 124 |
try:
|
| 125 |
qe_result = await dialogue_service.get_qe_result(request.history)
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
except Exception as e:
|
| 128 |
logger.error(f"Error in SSE chat stream while dialogue_service.get_qe_result: {str(e)}", stack_info=True)
|
| 129 |
yield "data: {\"event\": \"error\", \"data\":\""+str(e)+"\" }\n\n"
|
|
|
|
| 123 |
"""
|
| 124 |
try:
|
| 125 |
qe_result = await dialogue_service.get_qe_result(request.history)
|
| 126 |
+
qe_event = {
|
| 127 |
+
"event": "debug",
|
| 128 |
+
"data": {
|
| 129 |
+
"text": qe_result.debug_message
|
| 130 |
+
}
|
| 131 |
+
}
|
| 132 |
+
yield f"data: {json.dumps(qe_event, ensure_ascii=False)}\n\n"
|
| 133 |
except Exception as e:
|
| 134 |
logger.error(f"Error in SSE chat stream while dialogue_service.get_qe_result: {str(e)}", stack_info=True)
|
| 135 |
yield "data: {\"event\": \"error\", \"data\":\""+str(e)+"\" }\n\n"
|