# Provenance: uploaded via huggingface_hub by hugging2021 (commit 40f6dcf, verified)
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from abc import ABC, abstractmethod
import logging
logger = logging.getLogger(__name__)
@dataclass
class GeneratedResponse:
    """Response from generation.

    Bundles the generated answer with its supporting evidence and
    bookkeeping metadata so callers can render citations alongside text.
    """

    # The generated answer text (or a refusal/error message).
    answer: str
    # Confidence estimate in [0.0, 1.0]; 0.0 when no context was available.
    confidence: float
    # Deduplicated source records ({"id", "title", "score", "metadata"}).
    sources: List[Dict[str, Any]] = field(default_factory=list)
    # Citation entries ({"index", "source_id", "title", "style"}).
    citations: List[Dict[str, Any]] = field(default_factory=list)
    # Free-form extras, e.g. chunks_used / context_length counters.
    metadata: Dict[str, Any] = field(default_factory=dict)
class BaseGenerator(ABC):
    """Abstract interface every answer generator must implement."""

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        # Normalize to an empty dict so subclasses can call .get() freely.
        self.config = config or {}

    @abstractmethod
    async def generate(
        self, query: str, retrieved_chunks: List[Any], **kwargs
    ) -> GeneratedResponse:
        """Produce a GeneratedResponse for *query* from *retrieved_chunks*."""
        ...
class GroundedGenerator(BaseGenerator):
    """Grounded answer generator with evidence-based generation.

    Builds a source-annotated context from retrieved chunks, asks an LLM
    to answer strictly from that context, and returns the answer together
    with deduplicated sources, citations, and a confidence estimate.

    Chunk objects are expected to expose ``content``, ``score``,
    ``document_id`` and a ``metadata`` dict (assumed from usage below —
    TODO confirm against the retriever's chunk type).
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        super().__init__(config)
        self.citation_enabled = self.config.get("citation_enabled", True)
        self.citation_style = self.config.get("citation_style", "apa")
        self.min_confidence = self.config.get("min_confidence", 0.7)

    async def generate(
        self, query: str, retrieved_chunks: List[Any], **kwargs
    ) -> GeneratedResponse:
        """Generate an answer grounded in the retrieved context.

        Args:
            query: The user question.
            retrieved_chunks: Retrieved chunk objects (see class docstring).

        Returns:
            A GeneratedResponse; when no chunks are supplied, a
            zero-confidence refusal is returned instead of calling the LLM.
        """
        if not retrieved_chunks:
            return GeneratedResponse(
                answer="I don't have enough information to answer your question.",
                confidence=0.0,
            )
        context = self._build_context(retrieved_chunks)
        answer = await self._generate_answer(query, context)
        sources = self._extract_sources(retrieved_chunks)
        citations = self._generate_citations(sources)
        confidence = self._calculate_confidence(retrieved_chunks, answer)
        return GeneratedResponse(
            answer=answer,
            confidence=confidence,
            sources=sources,
            citations=citations,
            metadata={
                "chunks_used": len(retrieved_chunks),
                "context_length": len(context),
            },
        )

    def _build_context(self, chunks: List[Any]) -> str:
        """Build a context string from chunks, tagging each with a numbered source marker."""
        context_parts = []
        for i, chunk in enumerate(chunks):
            source_info = f"[Source {i + 1}]"
            if chunk.metadata.get("title"):
                source_info += f" ({chunk.metadata['title']})"
            context_parts.append(f"{source_info}\n{chunk.content}")
        return "\n\n".join(context_parts)

    async def _generate_answer(self, query: str, context: str) -> str:
        """Generate an answer via the OpenAI API, falling back when unavailable.

        Model parameters come from ``self.config`` (keys: ``model``,
        ``temperature``, ``max_tokens``) with the previous hard-coded values
        as defaults, so existing callers see identical behavior.

        Returns the model answer, the non-LLM fallback when the ``openai``
        package is not installed, or an error string on API failure.
        """
        prompt = self._create_prompt(query, context)
        try:
            from openai import OpenAI

            # NOTE(review): this client call is synchronous inside an async
            # method and will block the event loop — consider AsyncOpenAI.
            client = OpenAI()
            response = client.chat.completions.create(
                model=self.config.get("model", "gpt-4-turbo-preview"),
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that answers questions based on the provided context. Always cite your sources when providing information.",
                    },
                    {"role": "user", "content": prompt},
                ],
                temperature=self.config.get("temperature", 0.1),
                max_tokens=self.config.get("max_tokens", 1000),
            )
            return response.choices[0].message.content or "I couldn't generate an answer."
        except ImportError:
            # openai not installed: degrade gracefully to a context excerpt.
            return self._fallback_answer(query, context)
        except Exception as e:
            # Use the module logger (not the root logger) per file convention.
            logger.error("Error generating answer: %s", e)
            return f"Error generating answer: {str(e)}"

    def _create_prompt(self, query: str, context: str) -> str:
        """Create the generation prompt embedding context and question."""
        return f"""Based on the following context, answer the question. If the answer is not in the context, say so.
Context:
{context}
Question: {query}
Answer:"""

    def _fallback_answer(self, query: str, context: str) -> str:
        """Fallback answer generation without an LLM: echo a context excerpt."""
        return f"Based on the retrieved information, here is what I found regarding '{query}':\n\n{context[:500]}..."

    def _extract_sources(self, chunks: List[Any]) -> List[Dict[str, Any]]:
        """Extract source records from chunks, deduplicated by source id.

        The first chunk seen for a given id wins (keeps its score/metadata).
        """
        sources = []
        seen_ids = set()
        for chunk in chunks:
            source_id = chunk.metadata.get("source") or chunk.document_id
            if source_id not in seen_ids:
                seen_ids.add(source_id)
                sources.append(
                    {
                        "id": source_id,
                        "title": chunk.metadata.get("title", "Unknown"),
                        "score": chunk.score,
                        "metadata": chunk.metadata,
                    }
                )
        return sources

    def _generate_citations(self, sources: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Generate 1-indexed citation entries, one per deduplicated source."""
        citations = []
        for i, source in enumerate(sources):
            citation = {
                "index": i + 1,
                "source_id": source["id"],
                "title": source["title"],
                "style": self.citation_style,
            }
            citations.append(citation)
        return citations

    def _calculate_confidence(self, chunks: List[Any], answer: str) -> float:
        """Calculate a confidence score from the chunks' retrieval scores.

        Mean chunk score boosted by 1.2 and clamped to 1.0, rounded to two
        decimals. (The original code branched on ``min_confidence`` but both
        branches returned the same value — the dead branch was removed.)
        """
        if not chunks:
            return 0.0
        avg_score = sum(chunk.score for chunk in chunks) / len(chunks)
        return round(min(avg_score * 1.2, 1.0), 2)
class OpenAIGenerator(GroundedGenerator):
    """OpenAI-specific generator with additional features.

    Unlike GroundedGenerator._generate_answer, API failures here are
    re-raised after logging rather than turned into an error string.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        super().__init__(config)
        # Read from self.config (already normalized to a dict by
        # BaseGenerator.__init__) instead of re-checking the raw argument
        # for None three separate times.
        self.model = self.config.get("model", "gpt-4-turbo-preview")
        self.temperature = self.config.get("temperature", 0.1)
        self.max_tokens = self.config.get("max_tokens", 1000)

    async def _generate_answer(self, query: str, context: str) -> str:
        """Generate an answer using OpenAI with the configured model parameters.

        Returns the model answer, or the non-LLM fallback when the ``openai``
        package is not installed. Raises on API errors after logging.
        """
        prompt = self._create_prompt(query, context)
        try:
            from openai import OpenAI

            # NOTE(review): synchronous client call inside an async method —
            # blocks the event loop; consider AsyncOpenAI.
            client = OpenAI()
            response = client.chat.completions.create(
                model=self.model,
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that answers questions based on the provided context. Always cite your sources when providing information.",
                    },
                    {"role": "user", "content": prompt},
                ],
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )
            return response.choices[0].message.content or "I couldn't generate an answer."
        except ImportError:
            return self._fallback_answer(query, context)
        except Exception as e:
            # Module logger per file convention; logger.exception records the
            # traceback, and the error is propagated to the caller.
            logger.exception("OpenAI generation error: %s", e)
            raise