Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from sentence_transformers import SentenceTransformer | |
| import chromadb | |
| from chromadb.utils import embedding_functions | |
| import os | |
| from langdetect import detect | |
| # Проверяем наличие текстовых файлов и читаем их | |
| def load_text_files(): | |
| files = { | |
| "vampires": "vampires.txt", | |
| "werewolves": "werewolves.txt", | |
| "humans": "humans.txt" | |
| } | |
| loaded_data = {} | |
| for key, filename in files.items(): | |
| try: | |
| with open(filename, 'r', encoding='utf-8') as file: | |
| loaded_data[key] = file.read() | |
| except FileNotFoundError: | |
| print(f"Файл {filename} не найден") | |
| loaded_data[key] = "" | |
| return loaded_data | |
| # Инициализация модели для эмбеддингов | |
| def initialize_embedding_model(): | |
| return embedding_functions.SentenceTransformerEmbeddingFunction( | |
| model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" | |
| ) | |
| # Создание базы знаний | |
| def create_knowledge_base(text_data, embed_fn): | |
| client = chromadb.Client() | |
| try: | |
| collection = client.get_collection(name="knowledge_base") | |
| except: | |
| collection = client.create_collection( | |
| name="knowledge_base", | |
| embedding_function=embed_fn | |
| ) | |
| # Добавляем документы в коллекцию | |
| documents = [] | |
| metadatas = [] | |
| ids = [] | |
| for category, text in text_data.items(): | |
| if text: # только если текст не пустой | |
| # Разбиваем текст на предложения или абзацы | |
| paragraphs = [p for p in text.split('\n') if p.strip()] | |
| for i, paragraph in enumerate(paragraphs): | |
| documents.append(paragraph) | |
| metadatas.append({"category": category}) | |
| ids.append(f"{category}_{i}") | |
| if documents: | |
| collection.add( | |
| documents=documents, | |
| metadatas=metadatas, | |
| ids=ids | |
| ) | |
| return collection | |
| # Инициализация модели для ответов | |
| def initialize_llm_model(): | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
| model_name = "IlyaGusev/saiga_mistral_7b" | |
| tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) | |
| model = AutoModelForCausalLM.from_pretrained(model_name) | |
| pipe = pipeline( | |
| "text-generation", | |
| model=model, | |
| tokenizer=tokenizer, | |
| device="cpu" | |
| ) | |
| return pipe | |
| # Поиск релевантной информации | |
| def find_relevant_info(question, collection, embed_fn, n_results=3): | |
| results = collection.query( | |
| query_texts=[question], | |
| n_results=n_results | |
| ) | |
| context = "\n\n".join(results['documents'][0]) | |
| return context | |
| # Генерация ответа | |
| def generate_response(question, context, llm_pipe): | |
| system_prompt = """Ты - помощник, который отвечает на вопросы пользователя, используя предоставленную информацию. | |
| Отвечай только на основе предоставленного контекста. Если ответа нет в контексте, скажи, что не знаешь. | |
| Отвечай на русском языке.""" | |
| prompt = f"""<s>{system_prompt} | |
| Контекст: {context} | |
| Вопрос: {question} | |
| Ответ:""" | |
| output = llm_pipe( | |
| prompt, | |
| max_new_tokens=512, | |
| do_sample=True, | |
| temperature=0.7, | |
| top_p=0.9, | |
| repetition_penalty=1.2, | |
| eos_token_id=2 | |
| ) | |
| return output[0]["generated_text"][len(prompt):].strip() | |
| # Основная функция для обработки запросов | |
| def answer_question(question, history): | |
| # Определяем язык вопроса | |
| try: | |
| lang = detect(question) | |
| if lang != 'ru': | |
| return "Пожалуйста, задавайте вопросы на русском языке." | |
| except: | |
| pass | |
| # Загружаем данные (если еще не загружены) | |
| if not hasattr(answer_question, 'text_data'): | |
| answer_question.text_data = load_text_files() | |
| # Инициализируем модели (если еще не инициализированы) | |
| if not hasattr(answer_question, 'embed_fn'): | |
| answer_question.embed_fn = initialize_embedding_model() | |
| if not hasattr(answer_question, 'collection'): | |
| answer_question.collection = create_knowledge_base(answer_question.text_data, answer_question.embed_fn) | |
| if not hasattr(answer_question, 'llm_pipe'): | |
| answer_question.llm_pipe = initialize_llm_model() | |
| # Находим релевантный контекст | |
| context = find_relevant_info(question, answer_question.collection, answer_question.embed_fn) | |
| # Генерируем ответ | |
| response = generate_response(question, context, answer_question.llm_pipe) | |
| return response | |
| # Создаем интерфейс Gradio | |
| with gr.Blocks() as demo: | |
| gr.Markdown("## Чат-бот с доступом к текстовым файлам") | |
| gr.Markdown("Задавайте вопросы о вампирах, оборотнях или людях на русском языке") | |
| chatbot = gr.Chatbot(label="Диалог") | |
| msg = gr.Textbox(label="Ваш вопрос") | |
| clear = gr.Button("Очистить") | |
| def respond(message, chat_history): | |
| bot_message = answer_question(message, chat_history) | |
| chat_history.append((message, bot_message)) | |
| return "", chat_history | |
| msg.submit(respond, [msg, chatbot], [msg, chatbot]) | |
| clear.click(lambda: None, None, chatbot, queue=False) | |
| demo.launch() |