import os
import sys
from typing import Dict, Iterator, List, Optional

# Make the project root importable so the src package resolves.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))

from src.services.embedding_service import EmbeddingService
from src.services.gemini_client import GeminiClient

class RAGService:
    """Combine retrieval + generation workflow."""

    def __init__(self, embedding_service: EmbeddingService, gemini_client: GeminiClient) -> None:
        """
        Init RAG service.

        Args:
            embedding_service: EmbeddingService instance.
            gemini_client: GeminiClient instance.
        """
        self.embedding_service = embedding_service
        self.gemini_client = gemini_client

    def get_response(self, user_query: str, pdf_id: str, chat_history: Optional[List[Dict]] = None) -> str:
        """
        Retrieve context & generate answer.

        Args:
            user_query: User question.
            pdf_id: PDF identifier.
            chat_history: Prior messages.

        Returns:
            Assistant answer text.
        """
        chunks = self.embedding_service.find_similar_chunks(user_query, pdf_id=pdf_id, top_k=3)
        context = self._format_context(chunks)
        return self.gemini_client.generate_response(user_query, context=context, chat_history=chat_history)

    def stream_response(self, user_query: str, pdf_id: str, chat_history: Optional[List[Dict]] = None) -> Iterator[str]:
        """
        Retrieve context, then stream the model output.

        Returns:
            Iterator over chunks of the generated answer text.
        """
        chunks = self.embedding_service.find_similar_chunks(user_query, pdf_id=pdf_id, top_k=3)
        context = self._format_context(chunks)
        return self.gemini_client.stream_response(user_query, context=context, chat_history=chat_history)

    def _format_context(self, chunks: List[Dict]) -> str:
        """
        Format retrieved chunks for prompt.

        Args:
            chunks: Retrieval result list.

        Returns:
            Joined context string.
        """
        if not chunks:
            return ""
        lines: List[str] = []
        for idx, c in enumerate(chunks, start=1):
            # Keep only chunks with non-trivial similarity to the query.
            if c.get("similarity", 0) > 0.05:
                lines.append(f"[Chunk {idx} sim={c['similarity']:.2f}]\n{c.get('text', '')}")
        return "\n\n".join(lines)

    def retrieve_relevant_chunks(self, user_prompt: str, pdf_id: str, top_k: int = 3) -> List[Dict]:
        """
        Retrieve the chunks most relevant to the user prompt.
        """
        return self.embedding_service.find_similar_chunks(
            query=user_prompt,
            pdf_id=pdf_id,
            top_k=top_k
        )

    def generate_response_with_sources(self, user_query: str, pdf_id: str, chat_history: Optional[List[Dict]] = None) -> Dict:
        """
        Generate a response along with the source chunks that informed it.
        """
        try:
            # Retrieve relevant chunks and format them as prompt context
            relevant_chunks = self.retrieve_relevant_chunks(user_query, pdf_id)
            context = self._format_context(relevant_chunks)
            # Generate the grounded response
            response = self.gemini_client.generate_response(
                prompt=user_query,
                context=context,
                chat_history=chat_history
            )
            return {
                "response": response,
                "sources": relevant_chunks,
                "context_used": context
            }
        except Exception as e:
            # Fail soft: report the error in the payload instead of raising
            return {
                "response": f"Sorry, I encountered an error: {e}",
                "sources": [],
                "context_used": ""
            }
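

# Minimal usage sketch. Assumes EmbeddingService and GeminiClient can be
# constructed with no arguments and that a PDF has already been indexed under
# the given pdf_id -- both are assumptions, not guarantees of this module.
if __name__ == "__main__":
    rag = RAGService(EmbeddingService(), GeminiClient())

    # "example-pdf-id" is a hypothetical identifier; in practice it comes
    # from the upload/indexing step.
    result = rag.generate_response_with_sources(
        user_query="What is the main topic of this document?",
        pdf_id="example-pdf-id",
    )
    print(result["response"])
    for idx, source in enumerate(result["sources"], start=1):
        print(f"[{idx}] similarity={source.get('similarity', 0):.2f}")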