Spaces:
Sleeping
Sleeping
| from typing import List, Dict, Tuple | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer, CrossEncoder | |
| from rank_bm25 import BM25Okapi | |
| from groq import Groq | |
| import gradio as gr | |
| from dataclasses import dataclass | |
| import re | |
@dataclass
class Chunk:
    """A piece of a source document prepared for indexing and retrieval.

    Attributes:
        id: Identifier of the chunk; retrieval code looks chunks up by it.
        text: The chunk's text content.
        position: Ordinal position of the chunk within the document.
        metadata: Free-form extra information about the chunk. ``None``
            (the default) is replaced by a fresh empty dict per instance.
    """

    id: int
    text: str
    position: int
    metadata: Dict = None

    def __post_init__(self):
        # ``None`` is the sentinel for "no metadata": a literal ``{}`` default
        # is not allowed on a dataclass field, and every instance must own
        # its own dict rather than share one.
        if self.metadata is None:
            self.metadata = {}
class DocumentChunker:
    """Split text into sentence-aligned chunks with a word-level overlap.

    Args:
        chunk_size: Soft limit, in characters, on the size of a chunk. A
            chunk is closed as soon as adding the next sentence would exceed
            it; a single sentence longer than the limit is kept whole.
        overlap: Approximate number of characters repeated at the start of
            the next chunk. It is converted to a word count by assuming
            about five characters per word.
    """

    def __init__(self, chunk_size: int = 500, overlap: int = 100):
        self.chunk_size = chunk_size
        self.overlap = overlap

    def chunk_text(self, text: str) -> List[Chunk]:
        """Split ``text`` into chunks.

        Args:
            text: The document text.

        Returns:
            Chunks in document order, with ``id`` and ``position`` both set
            to the chunk's index. ``metadata['sentence_count']`` is the
            number of sentences added to the chunk, not counting the words
            carried over from the previous chunk. Empty or whitespace-only
            text yields an empty list.
        """
        # Split after sentence-ending punctuation so that every sentence
        # keeps its terminator in the chunk text.
        sentences = [
            part.strip()
            for part in re.split(r'(?<=[.!?])\s+', text)
            if part.strip()
        ]
        # The overlap is configured in characters; ~5 characters per word.
        overlap_word_count = max(0, int(self.overlap / 5))

        chunks = []
        current_chunk = ""
        sentence_count = 0
        for sentence in sentences:
            if current_chunk and len(current_chunk) + len(sentence) > self.chunk_size:
                chunks.append(self._build_chunk(len(chunks), current_chunk, sentence_count))
                # Guard the zero case explicitly: ``words[-0:]`` would copy
                # the whole previous chunk instead of nothing.
                tail = current_chunk.split()[-overlap_word_count:] if overlap_word_count else []
                current_chunk = ' '.join(tail + [sentence])
                sentence_count = 1
            else:
                current_chunk = f"{current_chunk} {sentence}" if current_chunk else sentence
                sentence_count += 1

        # Flush whatever is left as the final chunk.
        if current_chunk:
            chunks.append(self._build_chunk(len(chunks), current_chunk, sentence_count))
        return chunks

    @staticmethod
    def _build_chunk(chunk_id: int, text: str, sentence_count: int) -> Chunk:
        """Create a ``Chunk`` whose id and position are both ``chunk_id``."""
        return Chunk(
            id=chunk_id,
            text=text,
            position=chunk_id,
            metadata={'sentence_count': sentence_count},
        )
class HybridRetriever:
    """Hybrid retriever combining BM25 keyword search with semantic search.

    Args:
        model_name: Name of the sentence-transformers model used to embed
            chunks and queries.
    """

    def __init__(self, model_name: str = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'):
        self.embedding_model = SentenceTransformer(model_name)
        self.bm25 = None
        self.chunks = []
        self.embeddings = None
        self.tokenized_corpus = []

    def index_documents(self, chunks: List[Chunk]):
        """Build the BM25 index and the embedding matrix for ``chunks``.

        An empty list clears both indexes, after which every search method
        returns an empty result.
        """
        self.chunks = chunks
        texts = [chunk.text for chunk in chunks]
        self.tokenized_corpus = [self._tokenize(text) for text in texts]
        if not texts:
            # NOTE(review): rank_bm25 appears to compute an average document
            # length over the corpus and fail on an empty one - confirm.
            self.bm25 = None
            self.embeddings = None
            return
        self.bm25 = BM25Okapi(self.tokenized_corpus)
        print("Створення embeddings...")
        self.embeddings = self.embedding_model.encode(texts, show_progress_bar=True)

    def _tokenize(self, text: str) -> List[str]:
        """Lowercase ``text``, replace punctuation with spaces, drop tokens of 1-2 characters."""
        text = text.lower()
        text = re.sub(r'[^\wа-яїієґ\s]', ' ', text)
        tokens = text.split()
        return [t for t in tokens if len(t) > 2]

    def bm25_search(self, query: str, top_k: int = 10) -> List[Tuple[Chunk, float]]:
        """Return up to ``top_k`` ``(chunk, BM25 score)`` pairs, best first.

        Returns an empty list when nothing has been indexed.
        """
        if self.bm25 is None:
            return []
        tokenized_query = self._tokenize(query)
        scores = self.bm25.get_scores(tokenized_query)
        top_indices = np.argsort(scores)[::-1][:top_k]
        return [(self.chunks[i], scores[i]) for i in top_indices]

    def semantic_search(self, query: str, top_k: int = 10) -> List[Tuple[Chunk, float]]:
        """Return up to ``top_k`` ``(chunk, cosine similarity)`` pairs, best first.

        Returns an empty list when nothing has been indexed.
        """
        if self.embeddings is None:
            return []
        query_embedding = self.embedding_model.encode([query])[0]
        # Cosine similarity between the query and every chunk embedding.
        norms = np.linalg.norm(self.embeddings, axis=1) * np.linalg.norm(query_embedding)
        # A zero-length vector would turn the division into NaN and make the
        # ranking meaningless; dividing by 1 instead gives a similarity of 0.
        norms = np.where(norms > 0, norms, 1.0)
        similarities = np.dot(self.embeddings, query_embedding) / norms
        top_indices = np.argsort(similarities)[::-1][:top_k]
        return [(self.chunks[i], similarities[i]) for i in top_indices]

    def hybrid_search(self, query: str, top_k: int = 10,
                      alpha: float = 0.5) -> List[Tuple[Chunk, float]]:
        """Combine BM25 and semantic scores into a single ranking.

        Each search contributes its ``top_k * 2`` best candidates. BM25
        scores are divided by the best BM25 score among the candidates, and
        the final score is ``alpha * bm25 + (1 - alpha) * semantic``. A chunk
        missing from one of the candidate lists scores 0 for that part.

        Args:
            query: The search query.
            top_k: Maximum number of results.
            alpha: Weight of the BM25 part; ``1 - alpha`` weights the
                semantic part.

        Returns:
            Up to ``top_k`` ``(chunk, combined score)`` pairs, best first.
            Chunks are matched by ``id``, which is assumed to be unique.
        """
        bm25_results = self.bm25_search(query, top_k * 2)
        semantic_results = self.semantic_search(query, top_k * 2)

        bm25_scores = {chunk.id: score for chunk, score in bm25_results}
        semantic_scores = {chunk.id: score for chunk, score in semantic_results}
        # Both searches already return the chunk objects, so index them by id
        # once instead of scanning self.chunks for every result.
        chunk_by_id = {chunk.id: chunk for chunk, _ in bm25_results + semantic_results}

        # The normalisation constant is the same for every candidate.
        max_bm25 = max(bm25_scores.values(), default=0)

        combined_scores = {}
        for chunk_id in set(bm25_scores) | set(semantic_scores):
            bm25_score = bm25_scores.get(chunk_id, 0)
            bm25_score = bm25_score / max_bm25 if max_bm25 > 0 else 0
            semantic_score = semantic_scores.get(chunk_id, 0)
            combined_scores[chunk_id] = alpha * bm25_score + (1 - alpha) * semantic_score

        ranked = sorted(combined_scores.items(), key=lambda item: item[1], reverse=True)[:top_k]
        return [(chunk_by_id[chunk_id], score) for chunk_id, score in ranked]
class Reranker:
    """Re-score retrieved chunks against the query with a cross-encoder.

    Args:
        model_name: Name of the cross-encoder model to load.
    """

    def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2'):
        self.model = CrossEncoder(model_name)

    def rerank(self, query: str, chunks: List[Tuple[Chunk, float]],
               top_k: int = 5) -> List[Tuple[Chunk, float]]:
        """Return the ``top_k`` chunks ordered by cross-encoder score.

        Args:
            query: The user query.
            chunks: ``(chunk, score)`` pairs from a retriever; the incoming
                scores are ignored.
            top_k: Maximum number of results.

        Returns:
            ``(chunk, cross-encoder score)`` pairs, best first; an empty
            list when ``chunks`` is empty.
        """
        if not chunks:
            return []
        candidates = [candidate for candidate, _ in chunks]
        relevance = self.model.predict([[query, candidate.text] for candidate in candidates])
        ranked = sorted(zip(candidates, relevance), key=lambda pair: pair[1], reverse=True)
        return ranked[:top_k]
class RAGSystem:
    """Retrieval-augmented question answering over one loaded document.

    Args:
        api_key: Groq API key. Without it documents can still be loaded and
            indexed, but ``answer_question`` cannot generate answers.
        model: Name of the Groq chat model used for generation.
    """

    def __init__(self, api_key: str = None, model: str = "llama-3.3-70b-versatile"):
        self.chunker = DocumentChunker()
        self.retriever = HybridRetriever()
        self.reranker = Reranker()
        self.client = Groq(api_key=api_key) if api_key else None
        self.model = model
        self.chunks = []

    def load_document(self, text: str) -> str:
        """Chunk ``text``, index the chunks and return a status message."""
        self.chunks = self.chunker.chunk_text(text)
        self.retriever.index_documents(self.chunks)
        return f"Документ успішно завантажено. Створено {len(self.chunks)} чанків."

    def answer_question(self, question: str, retrieval_method: str = "hybrid",
                        use_reranker: bool = True, show_citations: bool = True) -> Tuple[str, List[Dict]]:
        """Answer ``question`` from the loaded document.

        Args:
            question: The user's question.
            retrieval_method: ``"bm25"``, ``"semantic"``, or anything else
                for hybrid search.
            use_reranker: Whether to rerank the retrieved chunks with the
                cross-encoder (keeping the best 5).
            show_citations: Whether to return the list of cited chunks.

        Returns:
            ``(answer, citations)``. Each citation is a dict with ``id``
            (1-based number used in the prompt), ``text`` and ``score``.
            When no document is loaded, no API key is configured, or the
            generation call fails, the first element is an explanatory
            message and the citation list is empty.
        """
        if not self.chunks:
            return "Спочатку завантажте документ!", []
        # Fail fast: without a client no answer can be generated, so skip
        # the retrieval and reranking work entirely.
        if self.client is None:
            return "API ключ не налаштовано!", []

        retrieved = self._retrieve(question, retrieval_method, use_reranker)
        prompt = self._build_prompt(question, retrieved)

        try:
            chat_completion = self.client.chat.completions.create(
                messages=[
                    {
                        "role": "system",
                        "content": "Ви - помічник, який відповідає на запитання на основі наданого контексту. Завжди цитуйте джерела у квадратних дужках."
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                model=self.model,
                temperature=0.3,
                max_tokens=2048,
            )
            answer = chat_completion.choices[0].message.content
        except Exception as e:
            # Boundary with the UI: report the failure as text instead of
            # propagating the exception.
            return f"Помилка при генерації відповіді: {str(e)}", []

        citations = []
        if show_citations:
            citations = [
                {"id": i + 1, "text": chunk.text, "score": float(score)}
                for i, (chunk, score) in enumerate(retrieved)
            ]
        return answer, citations

    def _retrieve(self, question: str, retrieval_method: str, use_reranker: bool) -> List[Tuple]:
        """Run the selected search and, optionally, the reranker."""
        if retrieval_method == "bm25":
            retrieved = self.retriever.bm25_search(question, top_k=10)
        elif retrieval_method == "semantic":
            retrieved = self.retriever.semantic_search(question, top_k=10)
        else:  # hybrid
            retrieved = self.retriever.hybrid_search(question, top_k=10)

        if use_reranker and retrieved:
            retrieved = self.reranker.rerank(question, retrieved, top_k=5)
        return retrieved

    @staticmethod
    def _build_prompt(question: str, retrieved: List[Tuple]) -> str:
        """Build the user prompt with the retrieved chunks numbered from 1."""
        context = "\n\n".join(
            f"[{i + 1}] {chunk.text}"
            for i, (chunk, _) in enumerate(retrieved)
        )
        return (
            "На основі наведеного контексту дайте відповідь на запитання.\n"
            "Обов'язково вказуйте номери джерел у квадратних дужках [1], [2] тощо.\n"
            "Контекст:\n"
            f"{context}\n"
            f"Запитання: {question}\n"
            "Відповідь (з цитуванням):"
        )
def create_gradio_interface():
    """Build the Gradio UI for loading a document and asking questions.

    Returns:
        The ``gr.Blocks`` application, not yet launched.
    """
    # NOTE(review): this single RAGSystem (including its API key) is shared
    # by every browser session served by the app - confirm that is intended.
    rag_system = None

    def load_file(file, api_key, model):
        """Read the uploaded file, index it and return (status, "", "")."""
        nonlocal rag_system
        if not api_key:
            return "Введіть Groq API ключ!", "", ""
        if file is None:
            return "Спочатку виберіть файл (.txt або .pdf)!", "", ""
        try:
            # Depending on the Gradio version/configuration the component
            # passes either a path string or an object with a ``name``.
            path = file if isinstance(file, str) else file.name
            filename = path.lower()
            if filename.endswith(".txt"):
                with open(path, 'r', encoding='utf-8') as f:
                    text = f.read()
            elif filename.endswith(".pdf"):
                import pdfplumber
                with pdfplumber.open(path) as pdf:
                    # extract_text() may yield None for a page without a
                    # text layer; treat such a page as empty.
                    text = "".join(
                        (page.extract_text() or "") + "\n"
                        for page in pdf.pages
                    )
            else:
                return "Формат файлу не підтримується! Завантажте .txt або .pdf", "", ""
            if not text.strip():
                # Nothing to index (e.g. a scanned PDF without a text layer).
                return "Не вдалося отримати текст із файлу: документ порожній.", "", ""
            rag_system = RAGSystem(api_key=api_key, model=model)
            status = rag_system.load_document(text)
            return status, "", ""
        except Exception as e:
            # UI boundary: show the failure in the status box.
            return f"Помилка завантаження файлу: {str(e)}", "", ""

    def answer(question, retrieval_method, use_reranker, show_citations):
        """Answer the question and format the citations for display."""
        if rag_system is None:
            return "Спочатку завантажте документ!", ""
        try:
            answer_text, citations = rag_system.answer_question(
                question,
                # Radio labels ("BM25", "Semantic", "Hybrid") -> method keys.
                retrieval_method.lower().replace(" ", ""),
                use_reranker,
                show_citations
            )
            citations_text = ""
            if citations and show_citations:
                entries = [
                    f"[{cit['id']}] {cit['text'][:200]}...\nScore: {cit['score']:.3f}\n\n"
                    for cit in citations
                ]
                citations_text = "\n\n📚 Джерела:\n\n" + "".join(entries)
            return answer_text, citations_text
        except Exception as e:
            return f"Помилка: {str(e)}", ""

    # Build the interface.
    with gr.Blocks() as demo:
        gr.Markdown("""
        # ⚡ RAG Question Answering System з Groq API
        Швидка система для відповідей на запитання з використанням RAG підходу та Groq LLMs.
        Завантажте українську книгу та отримайте відповіді на свої запитання!
        """)
        with gr.Row():
            with gr.Column(scale=2):
                api_key_input = gr.Textbox(
                    label="🔑 Groq API Key",
                    type="password",
                    placeholder="gsk_...",
                    info="Отримайте безкоштовний ключ на console.groq.com"
                )
                model_select = gr.Dropdown(
                    label="🤖 Модель Groq",
                    choices=[
                        "llama-3.3-70b-versatile",
                        "llama-3.1-70b-versatile",
                        "mixtral-8x7b-32768",
                        "gemma2-9b-it"
                    ],
                    value="llama-3.3-70b-versatile",
                    info="Llama 3.3 70B рекомендується для кращої якості"
                )
                file_input = gr.File(
                    label="📁 Завантажте книгу (.txt або .pdf)",
                    file_types=["text", ".txt", ".pdf"]
                )
                load_btn = gr.Button("📥 Завантажити документ", variant="primary", size="lg")
                status_output = gr.Textbox(label="Статус", interactive=False)
            with gr.Column(scale=1):
                gr.Markdown("### ⚙️ Налаштування пошуку")
                retrieval_method = gr.Radio(
                    ["BM25", "Semantic", "Hybrid"],
                    label="Метод пошуку",
                    value="Hybrid",
                    info="Hybrid комбінує обидва методи"
                )
                use_reranker = gr.Checkbox(
                    label="Використовувати Reranker",
                    value=True,
                    info="Покращує точність результатів"
                )
                show_citations = gr.Checkbox(
                    label="Показувати цитування",
                    value=True,
                    info="Відображає джерела інформації"
                )
        gr.Markdown("---")
        question_input = gr.Textbox(
            label="❓ Ваше запитання",
            placeholder="Введіть запитання про книгу...",
            lines=2
        )
        ask_btn = gr.Button("🔍 Знайти відповідь", variant="primary", size="lg")
        with gr.Row():
            with gr.Column():
                answer_output = gr.Textbox(
                    label="💡 Відповідь",
                    lines=10,
                )
            with gr.Column():
                citations_output = gr.Textbox(
                    label="📚 Джерела",
                    lines=10,
                )
        gr.Examples(
            examples=[
                "Про що ця книга?",
                "Хто головний герой?",
                "Що сталося в кінці?",
                "Які основні теми розглядаються?",
            ],
            inputs=question_input,
            label="💭 Приклади запитань"
        )
        # Loading a document also clears the previous answer and citations.
        load_btn.click(
            load_file,
            inputs=[file_input, api_key_input, model_select],
            outputs=[status_output, answer_output, citations_output]
        )
        # The button and pressing Enter in the question box do the same thing.
        ask_btn.click(
            answer,
            inputs=[question_input, retrieval_method, use_reranker, show_citations],
            outputs=[answer_output, citations_output]
        )
        question_input.submit(
            answer,
            inputs=[question_input, retrieval_method, use_reranker, show_citations],
            outputs=[answer_output, citations_output]
        )
    return demo
if __name__ == "__main__":
    # Script entry point: announce the start, build the UI, then serve it.
    print("Запуск RAG системи з Groq API")
    demo = create_gradio_interface()
    # share=True is passed through to Gradio's launch().
    demo.launch(share=True)