NEOAI-chatbot-2 / app.py
gotheartem's picture
Update app.py
b823769 verified
import os
from pathlib import Path
import aiosqlite
import gradio as gr
import tiktoken
from openai import AsyncOpenAI
# =========================
# Configuration
# =========================
DB_DIR = Path("./data")
# Single SQLite file that holds every chat table.
DB_PATH = DB_DIR / "chats.db"
# One table per chat id; the same list feeds the chat dropdown in the UI.
CHAT_IDS = [f"chat{number}" for number in range(1, 6)]
# Budget (in estimated tokens) for the prompt: system + history + new message.
INPUT_TOKEN_LIMIT = 4096
# Upper bound on the reply length, passed to the API as max_tokens.
OUTPUT_TOKEN_LIMIT = 1024
MODEL_NAME = "gpt-4o-mini"
def ensure_db_dir():
    """Create the directory that holds the SQLite database, if it is missing.

    Missing parent directories are created too; an existing directory is fine.
    """
    target_dir = DB_DIR
    target_dir.mkdir(exist_ok=True, parents=True)
# =========================
# Environment variables
# =========================
def _parse_key_value_pairs(raw, var_name):
    """Parse a ``"key1:value1,key2:value2"`` string into a dict.

    Only the first ``:`` of each entry separates key from value, so values
    (passwords, URLs such as ``https://...``) may themselves contain colons.
    Whitespace around entries, keys and values is stripped, and empty entries
    (e.g. from a trailing comma) are skipped.

    Args:
        raw: the raw environment-variable value.
        var_name: name of the variable, used only in error messages.

    Returns:
        Mapping of key to value; a repeated key keeps its last value.

    Raises:
        ValueError: if an entry has no ``:`` or an empty key, or if the string
            contains no entries at all.
    """
    pairs = {}
    for position, entry in enumerate(raw.split(","), start=1):
        entry = entry.strip()
        if not entry:
            continue
        key, separator, value = entry.partition(":")
        key = key.strip()
        if not separator or not key:
            # The entry itself is not echoed: it may contain a password.
            raise ValueError(
                f"Некорректная запись №{position} в переменной окружения {var_name}: "
                "ожидается формат 'ключ:значение'"
            )
        pairs[key] = value.strip()
    if not pairs:
        # An empty credential list must never reach launch(auth=...).
        raise ValueError(f"Переменная окружения {var_name} не содержит ни одной записи")
    return pairs


authorized_users_str = os.getenv("AUTHORIZED_USERS")
if not authorized_users_str:
    raise ValueError("Переменная окружения AUTHORIZED_USERS не установлена")
# Login credentials, "user:password,user2:password2"; used as Gradio auth pairs.
AUTHORIZED_USERS = _parse_key_value_pairs(authorized_users_str, "AUTHORIZED_USERS")

proctor_links = os.getenv("PROCTOR_LINKS")
if not proctor_links:
    raise ValueError("Переменная окружения PROCTOR_LINKS не установлена")
# Per-user answer to the "proctorlink" chat command, "user:link,user2:link2".
PROCTOR_LINKS = _parse_key_value_pairs(proctor_links, "PROCTOR_LINKS")

SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "Ты полезный ассистент.")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise ValueError("Переменная окружения OPENAI_API_KEY не установлена")
# =========================
# OpenAI client
# =========================
# One shared async client for the whole process; send_message uses it for
# chat completions.
client = AsyncOpenAI(api_key=OPENAI_API_KEY)
# =========================
# Token counting and history trimming
# =========================
def get_encoder(model_name: str = MODEL_NAME):
    """Return the tiktoken encoding for *model_name*.

    If tiktoken cannot resolve the model (any exception from the lookup,
    presumably an unknown model name), the ``cl100k_base`` encoding is used.
    """
    fallback_name = "cl100k_base"
    try:
        encoding = tiktoken.encoding_for_model(model_name)
    except Exception:
        encoding = tiktoken.get_encoding(fallback_name)
    return encoding
def count_message_tokens(messages, model_name: str = MODEL_NAME) -> int:
    """Estimate how many prompt tokens *messages* will take.

    This is an approximation, not the exact count billed by the API. It
    charges 4 framing tokens per message plus the encoded ``role`` and
    ``content``, and 2 extra tokens once for the whole list.

    Args:
        messages: iterable of dicts with optional "role" and "content" keys.
        model_name: model whose tokenizer should be used (see get_encoder).

    Returns:
        Estimated token count (2 for an empty list).
    """
    enc = get_encoder(model_name)

    def n_tokens(value):
        # Content may be None or a non-string; count it as text rather than
        # letting enc.encode raise TypeError.
        if value is None:
            text = ""
        elif isinstance(value, str):
            text = value
        else:
            text = str(value)
        # disallowed_special=() makes strings such as "<|endoftext|>" typed by a
        # user count as ordinary text; tiktoken's default raises ValueError.
        return len(enc.encode(text, disallowed_special=()))

    total = 2  # charged once for the whole list
    for msg in messages:
        total += 4 + n_tokens(msg.get("role", "")) + n_tokens(msg.get("content", ""))
    return total
def trim_messages_to_limit(
    system_prompt: str,
    history: list,
    new_user_message: str,
    model_name: str = MODEL_NAME,
    input_limit: int = INPUT_TOKEN_LIMIT,
):
    """Build the API message list so that it fits into *input_limit* tokens.

    The result is ``[system, *recent_history, user]``. The oldest history
    messages are dropped first until the estimate from count_message_tokens
    fits. If even ``[system, user]`` is too large, the history is dropped
    entirely and the user message is truncated at a token boundary.

    History entries are copied down to their "role" and "content" keys, with
    content coerced to ``str``. Any extra keys on the incoming dicts are
    therefore not forwarded to the API.

    Args:
        system_prompt: system instruction, always the first message.
        history: previous messages as dicts with "role"/"content"; may be
            None or empty.
        new_user_message: the new user turn, always the last message.
        model_name: model whose tokenizer is used for counting.
        input_limit: maximum estimated prompt size in tokens.

    Returns:
        List of ``{"role": ..., "content": ...}`` dicts for the chat API.
    """

    def _as_text(value):
        # None -> "", other non-strings -> str(value).
        if value is None:
            return ""
        return value if isinstance(value, str) else str(value)

    base_messages = [{"role": "system", "content": system_prompt}]
    user_message = {"role": "user", "content": new_user_message}
    candidate_history = [
        {"role": msg.get("role"), "content": _as_text(msg.get("content"))}
        for msg in (history or [])
    ]

    # count_message_tokens is additive: a constant list overhead plus an
    # independent cost per message. Measure the overhead once and tokenize
    # each history message once, instead of re-counting the whole list after
    # every dropped message (which was quadratic in the history length).
    list_overhead = count_message_tokens([], model_name)
    costs = [
        count_message_tokens([msg], model_name) - list_overhead
        for msg in candidate_history
    ]
    total = count_message_tokens(base_messages + [user_message], model_name) + sum(costs)

    # Drop the oldest history messages until the estimate fits.
    first_kept = 0
    while total > input_limit and first_kept < len(costs):
        total -= costs[first_kept]
        first_kept += 1

    if total <= input_limit:
        return base_messages + candidate_history[first_kept:] + [user_message]

    # Even without history the prompt is too long: truncate the user message.
    enc = get_encoder(model_name)
    system_only_tokens = count_message_tokens(base_messages, model_name)
    # 20 tokens of headroom for the user message's own framing (4 + role).
    allowed_for_user = max(1, input_limit - system_only_tokens - 20)
    # disallowed_special=() prevents ValueError on text like "<|endoftext|>".
    user_tokens = enc.encode(new_user_message, disallowed_special=())
    trimmed_user_message = enc.decode(user_tokens[:allowed_for_user])
    return base_messages + [{"role": "user", "content": trimmed_user_message}]
# =========================
# Database access
# =========================
async def init_db():
    """Create the message table for every chat id in CHAT_IDS if it is missing.

    Each table stores one row per message: who wrote it (username), the chat
    role, the text, and an insertion timestamp.
    """
    ensure_db_dir()
    async with aiosqlite.connect(DB_PATH) as conn:
        for table in CHAT_IDS:
            # Table names come from the CHAT_IDS constant, not from user input.
            await conn.execute(
                "\n"
                f"CREATE TABLE IF NOT EXISTS {table} (\n"
                "id INTEGER PRIMARY KEY AUTOINCREMENT,\n"
                "username TEXT NOT NULL,\n"
                "role TEXT NOT NULL,\n"
                "content TEXT NOT NULL,\n"
                "timestamp DATETIME DEFAULT CURRENT_TIMESTAMP\n"
                ")\n"
            )
        await conn.commit()
async def save_message(chat_id, username, role, content):
    """Append one message to the table of *chat_id* and commit it.

    Args:
        chat_id: chat table name; must be one of CHAT_IDS.
        username: owner of the message.
        role: chat role stored with the message (e.g. "user", "assistant").
        content: message text.

    Raises:
        ValueError: if *chat_id* is not one of CHAT_IDS.
    """
    # SQL identifiers cannot be bound as parameters, so the table name is
    # interpolated; accept only the known chat tables to rule out injection.
    if chat_id not in CHAT_IDS:
        raise ValueError(f"Некорректный идентификатор чата: {chat_id!r}")
    ensure_db_dir()
    async with aiosqlite.connect(DB_PATH) as db:
        await db.execute(
            f"INSERT INTO {chat_id} (username, role, content) VALUES (?, ?, ?)",
            (username, role, content),
        )
        await db.commit()
async def load_history(chat_id, username):
    """Return all messages of *username* in *chat_id*.

    Rows are ordered by timestamp, then by id.

    Args:
        chat_id: chat table name; must be one of CHAT_IDS.
        username: whose messages to load.

    Returns:
        List of ``{"role": ..., "content": ...}`` dicts.

    Raises:
        ValueError: if *chat_id* is not one of CHAT_IDS.
    """
    # The table name is interpolated into the SQL (identifiers cannot be bound
    # as parameters), so only the known chat tables are accepted.
    if chat_id not in CHAT_IDS:
        raise ValueError(f"Некорректный идентификатор чата: {chat_id!r}")
    ensure_db_dir()
    async with aiosqlite.connect(DB_PATH) as db:
        async with db.execute(
            f"SELECT role, content FROM {chat_id} "
            "WHERE username = ? "
            "ORDER BY timestamp ASC, id ASC",
            (username,),
        ) as cursor:
            return [{"role": role, "content": content} async for role, content in cursor]
async def clear_history(chat_id, username):
    """Delete every message of *username* from *chat_id*.

    Args:
        chat_id: chat table name; must be one of CHAT_IDS.
        username: whose messages to delete.

    Returns:
        An empty list (the new, empty history).

    Raises:
        ValueError: if *chat_id* is not one of CHAT_IDS.
    """
    # The table name is interpolated into the SQL (identifiers cannot be bound
    # as parameters), so only the known chat tables are accepted.
    if chat_id not in CHAT_IDS:
        raise ValueError(f"Некорректный идентификатор чата: {chat_id!r}")
    ensure_db_dir()
    async with aiosqlite.connect(DB_PATH) as db:
        await db.execute(
            f"DELETE FROM {chat_id} WHERE username = ?",
            (username,),
        )
        await db.commit()
    return []
# =========================
# Core logic
# =========================
async def send_message(message, history, username, chat_id):
    """Handle one message submitted from the chat UI.

    The message is persisted, then answered. The special command
    "proctorlink" (case-insensitive) answers with the user's entry from
    PROCTOR_LINKS. Any other message is sent to the model together with the
    trimmed history, and the answer is persisted too.

    Args:
        message: raw text from the input box.
        history: current chatbot messages (list of role/content dicts) or None.
        username: authenticated user, or a falsy value when unknown.
        chat_id: selected chat table.

    Returns:
        ``("", new_history)``: the empty string clears the input box and
        ``new_history`` replaces the chatbot contents. Validation failures
        append an assistant error message; blank input returns *history* as is.
    """

    def _error_reply(text):
        # Show an assistant-side error without touching the database.
        return "", (history or []) + [{"role": "assistant", "content": text}]

    if not username:
        return _error_reply("Ошибка: пользователь не авторизован.")
    if not chat_id or chat_id not in CHAT_IDS:
        return _error_reply("Ошибка: выбран некорректный чат.")
    if not message or not message.strip():
        return "", history

    prior = history or []
    user_text = message.strip()
    await save_message(chat_id, username, "user", user_text)

    if user_text.lower() == "proctorlink":
        answer = PROCTOR_LINKS.get(
            username,
            "Вас нет в списке пользователей с доступом к прокторингу.",
        )
    else:
        payload = trim_messages_to_limit(
            system_prompt=SYSTEM_PROMPT,
            history=prior,
            new_user_message=user_text,
            model_name=MODEL_NAME,
            input_limit=INPUT_TOKEN_LIMIT,
        )
        try:
            completion = await client.chat.completions.create(
                model=MODEL_NAME,
                messages=payload,
                temperature=0.7,
                max_tokens=OUTPUT_TOKEN_LIMIT,
                timeout=120,
            )
            answer = completion.choices[0].message.content or ""
        except Exception as exc:
            # Boundary handler: log and show a generic error to the user.
            print(f"OpenAI API error: {exc}")
            answer = "Error: Could not connect to the AI service."

    await save_message(chat_id, username, "assistant", answer)
    return "", prior + [
        {"role": "user", "content": user_text},
        {"role": "assistant", "content": answer},
    ]
# =========================
# UI helpers
# =========================
def set_username(request: gr.Request):
    """Return the logged-in username of the session, or None without a request."""
    if not request:
        return None
    return request.username
def update_chat_id(selected_chat):
    """Pass the dropdown selection through unchanged (stored as the chat id state)."""
    chosen_chat = selected_chat
    return chosen_chat
async def load_chat_history(username, chat_id):
    """Return the stored history for the chatbot component.

    Returns an empty list when the user or chat is not known yet. If loading
    fails, the error is printed and a single assistant message reporting the
    failure is returned instead.
    """
    if not (username and chat_id):
        return []
    try:
        return await load_history(chat_id, username)
    except Exception as err:
        print(f"Load history error: {err}")
        return [{"role": "assistant", "content": "Не удалось загрузить историю чата."}]
async def handle_clear_history(username, chat_id):
    """Delete the stored history and return the new chatbot contents.

    Returns an empty list, also when the user or chat is not known yet (in
    which case nothing is deleted). If deletion fails, the error is printed
    and a single assistant message reporting the failure is returned.
    """
    if not (username and chat_id):
        return []
    try:
        await clear_history(chat_id, username)
    except Exception as err:
        print(f"Clear history error: {err}")
        return [{"role": "assistant", "content": "Не удалось очистить историю чата."}]
    return []
# =========================
# CSS
# =========================
# Stylesheet passed to gr.Blocks(css=...). The selectors target the elem_id
# values of the interface components: #chatbot, #msg, #chat_selector and the
# three buttons. All buttons share the green background; the later
# #clear_history_btn rule overrides it with red.
custom_css = """
#chatbot {
width: 100% !important;
max-width: 1200px !important;
margin: 0 auto !important;
height: 65vh !important;
overflow-y: auto;
border: 1px solid #ccc;
border-radius: 8px;
}
#chatbot .message {
font-size: 16px !important;
padding: 10px;
}
#submit_btn, #load_history_btn, #clear_history_btn {
background-color: #4CAF50 !important;
color: white !important;
margin: 5px !important;
border-radius: 5px;
}
#clear_history_btn {
background-color: #ff4444 !important;
}
#msg {
width: 100% !important;
max-width: 1200px !important;
margin: 10px auto !important;
border-radius: 5px;
}
#chat_selector {
margin-bottom: 20px !important;
}
"""
# =========================
# Interface
# =========================
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    # Per-session values: the logged-in user (filled on page load) and the
    # selected chat table (starts at "chat1").
    username_state = gr.State()
    chat_id_state = gr.State(value="chat1")

    chat_selector = gr.Dropdown(
        choices=CHAT_IDS,
        label="Выберите чат",
        value="chat1",
        elem_id="chat_selector",
    )

    with gr.Column(elem_id="chatbot_container"):
        chatbot = gr.Chatbot(type="messages", elem_id="chatbot")
        msg = gr.Textbox(
            label="Ваше сообщение",
            placeholder="Введите ваше сообщение...",
            elem_id="msg",
        )
        with gr.Row():
            submit_btn = gr.Button("Отправить", elem_id="submit_btn")
            load_history_btn = gr.Button("Загрузить историю", elem_id="load_history_btn")
            clear_history_btn = gr.Button("Очистить историю", elem_id="clear_history_btn")

    # On page load: create the tables (CREATE TABLE IF NOT EXISTS, so safe to
    # repeat) and capture the authenticated username.
    demo.load(init_db, inputs=None, outputs=None)
    demo.load(set_username, inputs=None, outputs=username_state)

    # Copy the dropdown selection into chat_id_state.
    chat_selector.change(update_chat_id, inputs=chat_selector, outputs=chat_id_state)

    # Reload the visible history when the chat or the user changes, or on demand.
    for history_trigger in (chat_id_state.change, username_state.change, load_history_btn.click):
        history_trigger(
            load_chat_history,
            inputs=[username_state, chat_id_state],
            outputs=chatbot,
        )

    clear_history_btn.click(
        handle_clear_history,
        inputs=[username_state, chat_id_state],
        outputs=chatbot,
    )

    # Both the button and pressing Enter in the textbox send the message.
    for send_trigger in (submit_btn.click, msg.submit):
        send_trigger(
            send_message,
            inputs=[msg, chatbot, username_state, chat_id_state],
            outputs=[msg, chatbot],
        )
# =========================
# Launch
# =========================
if __name__ == "__main__":
    ensure_db_dir()
    # Gradio login: one (username, password) tuple per AUTHORIZED_USERS entry.
    credentials = list(AUTHORIZED_USERS.items())
    queued_app = demo.queue(default_concurrency_limit=40, max_size=50)
    queued_app.launch(
        auth=credentials,
        share=True,
        ssr_mode=False,
    )