Spaces:
Sleeping
Sleeping
| import os | |
| from pathlib import Path | |
| import aiosqlite | |
| import gradio as gr | |
| import tiktoken | |
| from openai import AsyncOpenAI | |
# =========================
# Configuration
# =========================
# SQLite storage: the directory is created on demand by ensure_db_dir().
DB_DIR = Path("./data")
DB_PATH = DB_DIR.joinpath("chats.db")
# Selectable chats. Each id is also used verbatim as an SQLite table name,
# so these must remain plain identifiers.
CHAT_IDS = ["chat1", "chat2", "chat3", "chat4", "chat5"]
# Token budget for the prompt (system prompt + history + new message).
INPUT_TOKEN_LIMIT = 4_096
# Token budget for the model reply (sent as max_tokens).
OUTPUT_TOKEN_LIMIT = 1_024
MODEL_NAME = "gpt-4o-mini"
def ensure_db_dir():
    """Create DB_DIR, including any missing parent directories, if it does not exist yet."""
    target = DB_DIR
    target.mkdir(exist_ok=True, parents=True)
# =========================
# Environment variables
# =========================
def _parse_env_pairs(raw: str, var_name: str) -> dict[str, str]:
    """Parse a ``"key:value,key:value"`` environment string into a dict.

    Each comma-separated entry is split on its FIRST colon only, so values may
    themselves contain colons (passwords, ``https://`` links). Whitespace
    around each entry is stripped and empty entries (e.g. left by a trailing
    comma) are skipped.

    Raises:
        ValueError: if an entry contains no colon. The entry itself is not
            echoed, because it may contain a password.
    """
    pairs = {}
    for entry in raw.split(","):
        entry = entry.strip()
        if not entry:
            continue
        key, sep, value = entry.partition(":")
        if not sep:
            raise ValueError(
                f"Некорректная запись в переменной окружения {var_name}: "
                "ожидается формат 'ключ:значение'"
            )
        pairs[key] = value
    return pairs


authorized_users_str = os.getenv("AUTHORIZED_USERS")
if not authorized_users_str:
    raise ValueError("Переменная окружения AUTHORIZED_USERS не установлена")
AUTHORIZED_USERS = _parse_env_pairs(authorized_users_str, "AUTHORIZED_USERS")
# An empty credential list would be passed to launch(auth=...) as [], which
# disables authentication entirely; refuse to start instead.
if not AUTHORIZED_USERS:
    raise ValueError("Переменная окружения AUTHORIZED_USERS не содержит пользователей")

proctor_links = os.getenv("PROCTOR_LINKS")
if not proctor_links:
    raise ValueError("Переменная окружения PROCTOR_LINKS не установлена")
# Values are URLs, hence the first-colon-only split inside _parse_env_pairs.
PROCTOR_LINKS = _parse_env_pairs(proctor_links, "PROCTOR_LINKS")

SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "Ты полезный ассистент.")

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise ValueError("Переменная окружения OPENAI_API_KEY не установлена")

# =========================
# OpenAI client
# =========================
# Single shared async client used by every request handler.
client = AsyncOpenAI(api_key=OPENAI_API_KEY)
# =========================
# Token counting and history trimming
# =========================
def get_encoder(model_name: str = MODEL_NAME):
    """Return the tiktoken encoding for *model_name*.

    Falls back to the ``cl100k_base`` encoding when tiktoken cannot map the
    model name to an encoding (it signals this with ``KeyError``). Any other
    error propagates unchanged instead of being masked by the fallback.
    """
    try:
        return tiktoken.encoding_for_model(model_name)
    except KeyError:
        return tiktoken.get_encoding("cl100k_base")
def count_message_tokens(messages, model_name: str = MODEL_NAME) -> int:
    """Approximate the prompt-token cost of a list of chat messages.

    Heuristic: each message costs 4 tokens of overhead plus its encoded
    ``role`` and ``content``, and the list as a whole adds 2 more tokens.

    Missing or ``None`` role/content values count as empty; non-string content
    is counted through ``str()``. Text that looks like a special token (e.g. a
    user typing ``<|endoftext|>``) is counted as ordinary text rather than
    making the encoder raise.
    """
    enc = get_encoder(model_name)

    def _n_tokens(value) -> int:
        if value is None:
            return 0
        text = value if isinstance(value, str) else str(value)
        # tiktoken's default disallowed_special="all" raises ValueError on
        # special-token strings; user input must never crash the counter.
        return len(enc.encode(text, disallowed_special=()))

    total = 0
    for msg in messages:
        total += 4 + _n_tokens(msg.get("role")) + _n_tokens(msg.get("content"))
    return total + 2
def trim_messages_to_limit(
    system_prompt: str,
    history: list,
    new_user_message: str,
    model_name: str = MODEL_NAME,
    input_limit: int = INPUT_TOKEN_LIMIT,
):
    """Build the API message list and fit it within *input_limit* tokens.

    The result is ``[system, *recent_history, user]``:

    * History entries are reduced to their ``role``/``content`` keys. Entries
      that are not dicts, or whose role/content is not a string, are dropped.
      This keeps extra keys attached on the UI side from being forwarded to
      the API.
    * The oldest history messages are discarded until the estimate from
      count_message_tokens fits the limit.
    * If even ``[system, user]`` is too long, the user message itself is
      truncated token-wise, keeping a 20-token safety margin.

    Returns:
        list of ``{"role": ..., "content": ...}`` dicts.
    """
    system_message = {"role": "system", "content": system_prompt}
    user_message = {"role": "user", "content": new_user_message}

    candidate_history = [
        {"role": msg["role"], "content": msg["content"]}
        for msg in (history or [])
        if isinstance(msg, dict)
        and isinstance(msg.get("role"), str)
        and isinstance(msg.get("content"), str)
    ]

    # count_message_tokens is additive: the list tail (2 tokens) is charged
    # once, so each history message's marginal cost is its solo count minus 2.
    # Computing these once avoids re-encoding the whole conversation on every
    # trimming step.
    total = count_message_tokens([system_message, user_message], model_name)
    costs = [count_message_tokens([msg], model_name) - 2 for msg in candidate_history]
    total += sum(costs)

    # Drop the oldest messages first until the estimate fits.
    start = 0
    while start < len(costs) and total > input_limit:
        total -= costs[start]
        start += 1

    if total <= input_limit:
        return [system_message, *candidate_history[start:], user_message]

    # Even without any history the request is too long: truncate the message.
    enc = get_encoder(model_name)
    system_only_tokens = count_message_tokens([system_message], model_name)
    allowed_for_user = max(1, input_limit - system_only_tokens - 20)
    user_tokens = enc.encode(new_user_message, disallowed_special=())
    trimmed_user_message = enc.decode(user_tokens[:allowed_for_user])
    return [system_message, {"role": "user", "content": trimmed_user_message}]
# =========================
# Database access
# =========================
async def init_db():
    """Create one message table per id in CHAT_IDS; existing tables are kept.

    Each table has an autoincrement id, username, role, content and a
    timestamp defaulting to the insertion time. The database directory is
    created before connecting.
    """
    statements = [
        f"""
CREATE TABLE IF NOT EXISTS {table} (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
)
"""
        for table in CHAT_IDS
    ]
    ensure_db_dir()
    async with aiosqlite.connect(DB_PATH) as conn:
        for ddl in statements:
            await conn.execute(ddl)
        await conn.commit()
async def save_message(chat_id, username, role, content):
    """Append one message row to the table belonging to *chat_id*.

    Args:
        chat_id: must be one of CHAT_IDS. It is interpolated into the SQL as
            a table name (identifiers cannot be bound as ``?`` parameters),
            so it is checked against that allow-list first.
        username, role, content: stored via bound parameters.

    Raises:
        ValueError: if *chat_id* is not in CHAT_IDS.
    """
    if chat_id not in CHAT_IDS:
        raise ValueError(f"Неизвестный чат: {chat_id!r}")
    ensure_db_dir()
    async with aiosqlite.connect(DB_PATH) as db:
        await db.execute(
            f"INSERT INTO {chat_id} (username, role, content) VALUES (?, ?, ?)",
            (username, role, content),
        )
        await db.commit()
async def load_history(chat_id, username):
    """Return all stored messages of *username* in *chat_id*, oldest first.

    Rows are ordered by timestamp, with the autoincrement id breaking ties
    between rows written in the same second.

    Args:
        chat_id: must be one of CHAT_IDS. It is used as the table name in the
            SQL text, so it is validated against that allow-list.
        username: matched via a bound parameter.

    Returns:
        list of ``{"role": ..., "content": ...}`` dicts.

    Raises:
        ValueError: if *chat_id* is not in CHAT_IDS.
    """
    if chat_id not in CHAT_IDS:
        raise ValueError(f"Неизвестный чат: {chat_id!r}")
    ensure_db_dir()
    async with aiosqlite.connect(DB_PATH) as db:
        async with db.execute(
            f"""
            SELECT role, content
            FROM {chat_id}
            WHERE username = ?
            ORDER BY timestamp ASC, id ASC
            """,
            (username,),
        ) as cursor:
            return [{"role": row[0], "content": row[1]} async for row in cursor]
async def clear_history(chat_id, username):
    """Delete every message of *username* in *chat_id* and return ``[]``.

    The empty list is returned so the result can be used directly as the new
    (empty) chat display value.

    Raises:
        ValueError: if *chat_id* is not in CHAT_IDS. It is used as the table
            name in the SQL text, so it must come from that allow-list.
    """
    if chat_id not in CHAT_IDS:
        raise ValueError(f"Неизвестный чат: {chat_id!r}")
    ensure_db_dir()
    async with aiosqlite.connect(DB_PATH) as db:
        await db.execute(
            f"DELETE FROM {chat_id} WHERE username = ?",
            (username,),
        )
        await db.commit()
    return []
# =========================
# Main logic
# =========================
async def send_message(message, history, username, chat_id):
    """Handle one user turn: persist it, produce a reply, return the new UI state.

    Args:
        message: raw text from the input box.
        history: current chatbot value (list of message dicts) or None.
        username: authenticated user name, or a falsy value if unknown.
        chat_id: selected chat; must be one of CHAT_IDS.

    Returns:
        ``("", new_history)``. The empty string clears the input box.
        Validation errors are appended to the displayed history as assistant
        messages and are not written to the database. The command
        ``proctorlink`` (case-insensitive) is answered from PROCTOR_LINKS
        without calling the model. Both user and assistant messages of a
        handled turn are saved.
    """
    if not username:
        error_msg = "Ошибка: пользователь не авторизован."
        return "", (history or []) + [{"role": "assistant", "content": error_msg}]
    if not chat_id or chat_id not in CHAT_IDS:
        error_msg = "Ошибка: выбран некорректный чат."
        return "", (history or []) + [{"role": "assistant", "content": error_msg}]
    if not message or not message.strip():
        return "", history

    history = history or []
    clean_message = message.strip()
    await save_message(chat_id, username, "user", clean_message)

    if clean_message.lower() == "proctorlink":
        assistant_response = PROCTOR_LINKS.get(
            username, "Вас нет в списке пользователей с доступом к прокторингу."
        )
    else:
        assistant_response = await _ask_model(history, clean_message)

    await save_message(chat_id, username, "assistant", assistant_response)
    return "", history + [
        {"role": "user", "content": clean_message},
        {"role": "assistant", "content": assistant_response},
    ]


async def _ask_model(history, clean_message):
    """Return the model's reply to *clean_message* given *history*.

    On any API failure the error is printed and a fixed error string is
    returned instead of raising, so the UI always receives a reply.
    """
    # Forward only role/content of text messages. The chatbot value may carry
    # extra UI-side keys or non-text entries that must not reach the API.
    context = [
        {"role": m["role"], "content": m["content"]}
        for m in history
        if isinstance(m, dict)
        and isinstance(m.get("role"), str)
        and isinstance(m.get("content"), str)
    ]
    messages = trim_messages_to_limit(
        system_prompt=SYSTEM_PROMPT,
        history=context,
        new_user_message=clean_message,
        model_name=MODEL_NAME,
        input_limit=INPUT_TOKEN_LIMIT,
    )
    try:
        response = await client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=0.7,
            max_tokens=OUTPUT_TOKEN_LIMIT,
            timeout=120,
        )
        return response.choices[0].message.content or ""
    except Exception as e:  # UI boundary: report, never crash the handler
        print(f"OpenAI API error: {e}")
        return "Error: Could not connect to the AI service."
# =========================
# UI helpers
# =========================
def set_username(request: gr.Request):
    """Return the logged-in user's name from the Gradio request, or None without a request."""
    if not request:
        return None
    return request.username
def update_chat_id(selected_chat):
    """Pass the dropdown selection through unchanged so it can be stored in state."""
    chat_id = selected_chat
    return chat_id
async def load_chat_history(username, chat_id):
    """Return the stored history for the UI, or a single error message on failure.

    Returns ``[]`` when either the user or the chat is not set.
    """
    if not (username and chat_id):
        return []
    try:
        rows = await load_history(chat_id, username)
    except Exception as exc:
        print(f"Load history error: {exc}")
        return [{"role": "assistant", "content": "Не удалось загрузить историю чата."}]
    return rows
async def handle_clear_history(username, chat_id):
    """Clear the user's history in the chat and return the new (empty) display value.

    Returns ``[]`` on success or when the user/chat is not set. On failure it
    returns a single assistant error message instead.
    """
    if not (username and chat_id):
        return []
    try:
        await clear_history(chat_id, username)
    except Exception as exc:
        print(f"Clear history error: {exc}")
        return [{"role": "assistant", "content": "Не удалось очистить историю чата."}]
    return []
# =========================
# CSS
# =========================
# Stylesheet passed to gr.Blocks(css=...). Selectors target the elem_id values
# given to the components in the interface definition. #clear_history_btn
# appears in the shared green button rule and again afterwards; the later rule
# has equal specificity, so it overrides the background with red.
custom_css = """
#chatbot {
width: 100% !important;
max-width: 1200px !important;
margin: 0 auto !important;
height: 65vh !important;
overflow-y: auto;
border: 1px solid #ccc;
border-radius: 8px;
}
#chatbot .message {
font-size: 16px !important;
padding: 10px;
}
#submit_btn, #load_history_btn, #clear_history_btn {
background-color: #4CAF50 !important;
color: white !important;
margin: 5px !important;
border-radius: 5px;
}
#clear_history_btn {
background-color: #ff4444 !important;
}
#msg {
width: 100% !important;
max-width: 1200px !important;
margin: 10px auto !important;
border-radius: 5px;
}
#chat_selector {
margin-bottom: 20px !important;
}
"""
# =========================
# Interface
# =========================
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    # Per-session state: authenticated user name and currently selected chat.
    username_state = gr.State()
    chat_id_state = gr.State(value="chat1")

    chat_selector = gr.Dropdown(
        choices=CHAT_IDS,
        label="Выберите чат",
        value="chat1",
        elem_id="chat_selector",
    )
    with gr.Column(elem_id="chatbot_container"):
        chatbot = gr.Chatbot(type="messages", elem_id="chatbot")
    msg = gr.Textbox(
        label="Ваше сообщение",
        placeholder="Введите ваше сообщение...",
        elem_id="msg",
    )
    with gr.Row():
        submit_btn = gr.Button("Отправить", elem_id="submit_btn")
        load_history_btn = gr.Button("Загрузить историю", elem_id="load_history_btn")
        clear_history_btn = gr.Button("Очистить историю", elem_id="clear_history_btn")

    # Create the tables BEFORE resolving the user. Setting username_state
    # triggers load_chat_history below. As two independent load events, the
    # history query could run before init_db had created the tables.
    demo.load(init_db, inputs=None, outputs=None).then(
        set_username, inputs=None, outputs=username_state
    )

    chat_selector.change(update_chat_id, inputs=chat_selector, outputs=chat_id_state)

    # Reload the visible history whenever the chat or the user changes.
    chat_id_state.change(
        load_chat_history,
        inputs=[username_state, chat_id_state],
        outputs=chatbot,
    )
    username_state.change(
        load_chat_history,
        inputs=[username_state, chat_id_state],
        outputs=chatbot,
    )

    load_history_btn.click(
        load_chat_history,
        inputs=[username_state, chat_id_state],
        outputs=chatbot,
    )
    clear_history_btn.click(
        handle_clear_history,
        inputs=[username_state, chat_id_state],
        outputs=chatbot,
    )

    # Button click and Enter in the textbox both send the message.
    submit_btn.click(
        send_message,
        inputs=[msg, chatbot, username_state, chat_id_state],
        outputs=[msg, chatbot],
    )
    msg.submit(
        send_message,
        inputs=[msg, chatbot, username_state, chat_id_state],
        outputs=[msg, chatbot],
    )
# =========================
# Entry point
# =========================
if __name__ == "__main__":
    ensure_db_dir()
    # Gradio expects auth as a list of (username, password) tuples.
    credentials = list(AUTHORIZED_USERS.items())
    demo.queue(default_concurrency_limit=40, max_size=50).launch(
        auth=credentials,
        share=True,
        ssr_mode=False,
    )