rag / app.py
Marriam855's picture
Upload 2 files
08f7d21 verified
from typing import List, Dict, Tuple
import numpy as np
from sentence_transformers import SentenceTransformer, CrossEncoder
from rank_bm25 import BM25Okapi
from groq import Groq
import gradio as gr
from dataclasses import dataclass
import re
@dataclass
class Chunk:
id: int
text: str
position: int
metadata: Dict = None
def __post_init__(self):
if self.metadata is None:
self.metadata = {}
class DocumentChunker:
def __init__(self, chunk_size: int = 500, overlap: int = 100):
self.chunk_size = chunk_size
self.overlap = overlap
def chunk_text(self, text: str) -> List[Chunk]:
# Розбиття на речення
sentences = re.split(r'[.!?]+', text)
sentences = [s.strip() for s in sentences if s.strip()]
chunks = []
current_chunk = ""
chunk_id = 0
for sentence in sentences:
if len(current_chunk) + len(sentence) > self.chunk_size and current_chunk:
chunks.append(Chunk(
id=chunk_id,
text=current_chunk.strip(),
position=chunk_id,
metadata={'sentence_count': len(current_chunk.split('.'))}
))
# Створення overlap
words = current_chunk.split()
overlap_words = words[-int(self.overlap / 5):] if len(words) > int(self.overlap / 5) else words
current_chunk = ' '.join(overlap_words) + ' ' + sentence
chunk_id += 1
else:
current_chunk += ' ' + sentence
# Додавання останнього чанка
if current_chunk.strip():
chunks.append(Chunk(
id=chunk_id,
text=current_chunk.strip(),
position=chunk_id,
metadata={'sentence_count': len(current_chunk.split('.'))}
))
return chunks
class HybridRetriever:
"""Гібридний retriever з BM25 та semantic search"""
def __init__(self, model_name: str = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'):
self.embedding_model = SentenceTransformer(model_name)
self.bm25 = None
self.chunks = []
self.embeddings = None
self.tokenized_corpus = []
def index_documents(self, chunks: List[Chunk]):
self.chunks = chunks
texts = [chunk.text for chunk in chunks]
self.tokenized_corpus = [self._tokenize(text) for text in texts]
self.bm25 = BM25Okapi(self.tokenized_corpus)
print("Створення embeddings...")
self.embeddings = self.embedding_model.encode(texts, show_progress_bar=True)
def _tokenize(self, text: str) -> List[str]:
text = text.lower()
text = re.sub(r'[^\wа-яїієґ\s]', ' ', text)
tokens = text.split()
return [t for t in tokens if len(t) > 2]
def bm25_search(self, query: str, top_k: int = 10) -> List[Tuple[Chunk, float]]:
if self.bm25 is None:
return []
tokenized_query = self._tokenize(query)
scores = self.bm25.get_scores(tokenized_query)
top_indices = np.argsort(scores)[::-1][:top_k]
results = [(self.chunks[i], scores[i]) for i in top_indices]
return results
def semantic_search(self, query: str, top_k: int = 10) -> List[Tuple[Chunk, float]]:
if self.embeddings is None:
return []
query_embedding = self.embedding_model.encode([query])[0]
# Косинусна подібність
similarities = np.dot(self.embeddings, query_embedding) / (
np.linalg.norm(self.embeddings, axis=1) * np.linalg.norm(query_embedding)
)
top_indices = np.argsort(similarities)[::-1][:top_k]
results = [(self.chunks[i], similarities[i]) for i in top_indices]
return results
def hybrid_search(self, query: str, top_k: int = 10,
alpha: float = 0.5) -> List[Tuple[Chunk, float]]:
bm25_results = self.bm25_search(query, top_k * 2)
semantic_results = self.semantic_search(query, top_k * 2)
bm25_scores = {chunk.id: score for chunk, score in bm25_results}
semantic_scores = {chunk.id: score for chunk, score in semantic_results}
combined_scores = {}
all_ids = set(bm25_scores.keys()) | set(semantic_scores.keys())
for chunk_id in all_ids:
bm25_score = bm25_scores.get(chunk_id, 0)
semantic_score = semantic_scores.get(chunk_id, 0)
if bm25_results:
max_bm25 = max(bm25_scores.values())
bm25_score = bm25_score / max_bm25 if max_bm25 > 0 else 0
combined_scores[chunk_id] = alpha * bm25_score + (1 - alpha) * semantic_score
sorted_ids = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
results = [(next(c for c in self.chunks if c.id == cid), score)
for cid, score in sorted_ids]
return results
class Reranker:
def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2'):
self.model = CrossEncoder(model_name)
def rerank(self, query: str, chunks: List[Tuple[Chunk, float]],
top_k: int = 5) -> List[Tuple[Chunk, float]]:
if not chunks:
return []
pairs = [[query, chunk.text] for chunk, _ in chunks]
scores = self.model.predict(pairs)
results = list(zip([chunk for chunk, _ in chunks], scores))
results.sort(key=lambda x: x[1], reverse=True)
return results[:top_k]
class RAGSystem:
def __init__(self, api_key: str = None, model: str = "llama-3.3-70b-versatile"):
self.chunker = DocumentChunker()
self.retriever = HybridRetriever()
self.reranker = Reranker()
self.client = Groq(api_key=api_key) if api_key else None
self.model = model
self.chunks = []
def load_document(self, text: str) -> str:
# Chunking
self.chunks = self.chunker.chunk_text(text)
# Індексація
self.retriever.index_documents(self.chunks)
return f"Документ успішно завантажено. Створено {len(self.chunks)} чанків."
def answer_question(self, question: str, retrieval_method: str = "hybrid",
use_reranker: bool = True, show_citations: bool = True) -> Tuple[str, List[Dict]]:
"""Відповідь на запитання"""
if not self.chunks:
return "Спочатку завантажте документ!", []
# Retrieval
if retrieval_method == "bm25":
retrieved = self.retriever.bm25_search(question, top_k=10)
elif retrieval_method == "semantic":
retrieved = self.retriever.semantic_search(question, top_k=10)
else: # hybrid
retrieved = self.retriever.hybrid_search(question, top_k=10)
# Reranking
if use_reranker and retrieved:
retrieved = self.reranker.rerank(question, retrieved, top_k=5)
# Генерація відповіді з Groq
if self.client is None:
return "API ключ не налаштовано!", []
context = "\n\n".join([
f"[{i + 1}] {chunk.text}"
for i, (chunk, _) in enumerate(retrieved)
])
prompt = f"""На основі наведеного контексту дайте відповідь на запитання.
Обов'язково вказуйте номери джерел у квадратних дужках [1], [2] тощо.
Контекст:
{context}
Запитання: {question}
Відповідь (з цитуванням):"""
try:
chat_completion = self.client.chat.completions.create(
messages=[
{
"role": "system",
"content": "Ви - помічник, який відповідає на запитання на основі наданого контексту. Завжди цитуйте джерела у квадратних дужках."
},
{
"role": "user",
"content": prompt
}
],
model=self.model,
temperature=0.3,
max_tokens=2048,
)
answer = chat_completion.choices[0].message.content
except Exception as e:
return f"Помилка при генерації відповіді: {str(e)}", []
# Формуємо цитування
citations = []
if show_citations:
citations = [
{"id": i + 1, "text": chunk.text, "score": float(score)}
for i, (chunk, score) in enumerate(retrieved)
]
return answer, citations
def create_gradio_interface():
rag_system = None
def load_file(file, api_key, model):
nonlocal rag_system
if not api_key:
return "Введіть Groq API ключ!", "", ""
try:
filename = file.name.lower()
if filename.endswith(".txt"):
with open(file.name, 'r', encoding='utf-8') as f:
text = f.read()
elif filename.endswith(".pdf"):
import pdfplumber
text = ""
with pdfplumber.open(file.name) as pdf:
for page in pdf.pages:
text += page.extract_text() + "\n"
else:
return "Формат файлу не підтримується! Завантажте .txt або .pdf", "", ""
rag_system = RAGSystem(api_key=api_key, model=model)
status = rag_system.load_document(text)
return status, "", ""
except Exception as e:
return f"Помилка завантаження файлу: {str(e)}", "", ""
def answer(question, retrieval_method, use_reranker, show_citations):
if rag_system is None:
return "Спочатку завантажте документ!", ""
try:
answer_text, citations = rag_system.answer_question(
question,
retrieval_method.lower().replace(" ", ""),
use_reranker,
show_citations
)
citations_text = ""
if citations and show_citations:
citations_text = "\n\n📚 Джерела:\n\n"
for cit in citations:
citations_text += f"[{cit['id']}] {cit['text'][:200]}...\n"
citations_text += f"Score: {cit['score']:.3f}\n\n"
return answer_text, citations_text
except Exception as e:
return f"Помилка: {str(e)}", ""
# Створення інтерфейсу
with gr.Blocks() as demo:
gr.Markdown("""
# ⚡ RAG Question Answering System з Groq API
Швидка система для відповідей на запитання з використанням RAG підходу та Groq LLMs.
Завантажте українську книгу та отримайте відповіді на свої запитання!
""")
with gr.Row():
with gr.Column(scale=2):
api_key_input = gr.Textbox(
label="🔑 Groq API Key",
type="password",
placeholder="gsk_...",
info="Отримайте безкоштовний ключ на console.groq.com"
)
model_select = gr.Dropdown(
label="🤖 Модель Groq",
choices=[
"llama-3.3-70b-versatile",
"llama-3.1-70b-versatile",
"mixtral-8x7b-32768",
"gemma2-9b-it"
],
value="llama-3.3-70b-versatile",
info="Llama 3.3 70B рекомендується для кращої якості"
)
file_input = gr.File(
label="📁 Завантажте книгу (.txt або .pdf)",
file_types=["text", ".txt", ".pdf"]
)
load_btn = gr.Button("📥 Завантажити документ", variant="primary", size="lg")
status_output = gr.Textbox(label="Статус", interactive=False)
with gr.Column(scale=1):
gr.Markdown("### ⚙️ Налаштування пошуку")
retrieval_method = gr.Radio(
["BM25", "Semantic", "Hybrid"],
label="Метод пошуку",
value="Hybrid",
info="Hybrid комбінує обидва методи"
)
use_reranker = gr.Checkbox(
label="Використовувати Reranker",
value=True,
info="Покращує точність результатів"
)
show_citations = gr.Checkbox(
label="Показувати цитування",
value=True,
info="Відображає джерела інформації"
)
gr.Markdown("---")
question_input = gr.Textbox(
label="❓ Ваше запитання",
placeholder="Введіть запитання про книгу...",
lines=2
)
ask_btn = gr.Button("🔍 Знайти відповідь", variant="primary", size="lg")
with gr.Row():
with gr.Column():
answer_output = gr.Textbox(
label="💡 Відповідь",
lines=10,
)
with gr.Column():
citations_output = gr.Textbox(
label="📚 Джерела",
lines=10,
)
gr.Examples(
examples=[
"Про що ця книга?",
"Хто головний герой?",
"Що сталося в кінці?",
"Які основні теми розглядаються?",
],
inputs=question_input,
label="💭 Приклади запитань"
)
load_btn.click(
load_file,
inputs=[file_input, api_key_input, model_select],
outputs=[status_output, answer_output, citations_output]
)
ask_btn.click(
answer,
inputs=[question_input, retrieval_method, use_reranker, show_citations],
outputs=[answer_output, citations_output]
)
question_input.submit(
answer,
inputs=[question_input, retrieval_method, use_reranker, show_citations],
outputs=[answer_output, citations_output]
)
return demo
if __name__ == "__main__":
print("Запуск RAG системи з Groq API")
demo = create_gradio_interface()
demo.launch(share=True)