Spaces:
Build error
Build error
| from dataclasses import dataclass, field | |
| from typing import Any, Dict, List, Optional | |
| from abc import ABC, abstractmethod | |
| import logging | |
| logger = logging.getLogger(__name__) | |
@dataclass
class GeneratedResponse:
    """Response from generation.

    Carries the generated answer together with its supporting evidence.
    The ``@dataclass`` decorator is required: the fields use
    ``field(default_factory=...)`` and callers construct instances with
    keyword arguments, neither of which works on a plain class.
    """

    # The generated answer text.
    answer: str
    # Confidence score for the answer (0.0 when nothing could be generated).
    confidence: float
    # Source documents the answer was drawn from.
    sources: List[Dict[str, Any]] = field(default_factory=list)
    # Citation entries referencing the sources.
    citations: List[Dict[str, Any]] = field(default_factory=list)
    # Extra generation details (e.g. chunk counts, context length).
    metadata: Dict[str, Any] = field(default_factory=dict)
class BaseGenerator(ABC):
    """Abstract base class for answer generators."""

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Store the generator configuration, defaulting to an empty dict."""
        self.config = config or {}

    @abstractmethod
    async def generate(
        self, query: str, retrieved_chunks: List[Any], **kwargs
    ) -> GeneratedResponse:
        """Generate an answer based on the query and retrieved context.

        Marked ``@abstractmethod`` (the import was present but the decorator
        was missing) so subclasses must override it and the base class
        cannot be instantiated directly.
        """
        pass
class GroundedGenerator(BaseGenerator):
    """Grounded answer generator with evidence-based generation.

    Builds a context string from retrieved chunks, asks an LLM to answer
    the query against that context, and attaches sources, citations, and
    a confidence score to the response.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Read citation/confidence settings from the config.

        Recognized keys: ``citation_enabled`` (bool, default True),
        ``citation_style`` (str, default "apa"), ``min_confidence``
        (float, default 0.7).
        """
        super().__init__(config)
        self.citation_enabled = self.config.get("citation_enabled", True)
        self.citation_style = self.config.get("citation_style", "apa")
        self.min_confidence = self.config.get("min_confidence", 0.7)

    async def generate(
        self, query: str, retrieved_chunks: List[Any], **kwargs
    ) -> GeneratedResponse:
        """Generate an answer grounded in the retrieved context.

        Returns a zero-confidence fallback response when no chunks were
        retrieved; otherwise builds the context, generates the answer,
        and assembles sources/citations/confidence into the response.
        """
        if not retrieved_chunks:
            return GeneratedResponse(
                answer="I don't have enough information to answer your question.",
                confidence=0.0,
            )
        context = self._build_context(retrieved_chunks)
        answer = await self._generate_answer(query, context)
        sources = self._extract_sources(retrieved_chunks)
        citations = self._generate_citations(sources)
        confidence = self._calculate_confidence(retrieved_chunks, answer)
        return GeneratedResponse(
            answer=answer,
            confidence=confidence,
            sources=sources,
            citations=citations,
            metadata={
                "chunks_used": len(retrieved_chunks),
                "context_length": len(context),
            },
        )

    def _build_context(self, chunks: List[Any]) -> str:
        """Join chunk contents into one string, each prefixed with a numbered
        ``[Source N]`` marker (plus the chunk's title when present)."""
        context_parts = []
        for i, chunk in enumerate(chunks):
            source_info = f"[Source {i + 1}]"
            if chunk.metadata.get("title"):
                source_info += f" ({chunk.metadata['title']})"
            context_parts.append(f"{source_info}\n{chunk.content}")
        return "\n\n".join(context_parts)

    async def _generate_answer(self, query: str, context: str) -> str:
        """Generate an answer with the OpenAI chat completions API.

        Falls back to a plain-text excerpt of the context when the
        ``openai`` package is not installed; any other error is logged and
        reported in the returned answer text rather than raised.
        """
        prompt = self._create_prompt(query, context)
        try:
            # Imported lazily so the generator degrades gracefully when the
            # openai package is absent.
            from openai import OpenAI

            client = OpenAI()
            response = client.chat.completions.create(
                model="gpt-4-turbo-preview",
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that answers questions based on the provided context. Always cite your sources when providing information.",
                    },
                    {"role": "user", "content": prompt},
                ],
                temperature=0.1,
                max_tokens=1000,
            )
            return response.choices[0].message.content or "I couldn't generate an answer."
        except ImportError:
            return self._fallback_answer(query, context)
        except Exception as e:
            # Use the module-level logger (not the root logger, as before)
            # and record the traceback for debugging.
            logger.exception("Error generating answer: %s", e)
            return f"Error generating answer: {str(e)}"

    def _create_prompt(self, query: str, context: str) -> str:
        """Create the generation prompt embedding the context and question."""
        return f"""Based on the following context, answer the question. If the answer is not in the context, say so.

Context:
{context}

Question: {query}

Answer:"""

    def _fallback_answer(self, query: str, context: str) -> str:
        """Fallback answer without an LLM: echo the first 500 context chars."""
        return f"Based on the retrieved information, here is what I found regarding '{query}':\n\n{context[:500]}..."

    def _extract_sources(self, chunks: List[Any]) -> List[Dict[str, Any]]:
        """Extract de-duplicated source records from chunks.

        A chunk's source id is its ``metadata['source']`` when set, falling
        back to ``document_id``; only the first chunk per source is kept.
        """
        sources = []
        seen_ids = set()
        for chunk in chunks:
            source_id = chunk.metadata.get("source") or chunk.document_id
            if source_id not in seen_ids:
                seen_ids.add(source_id)
                sources.append(
                    {
                        "id": source_id,
                        "title": chunk.metadata.get("title", "Unknown"),
                        "score": chunk.score,
                        "metadata": chunk.metadata,
                    }
                )
        return sources

    def _generate_citations(self, sources: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Generate one 1-indexed citation entry per source, tagged with the
        configured citation style."""
        citations = []
        for i, source in enumerate(sources):
            citation = {
                "index": i + 1,
                "source_id": source["id"],
                "title": source["title"],
                "style": self.citation_style,
            }
            citations.append(citation)
        return citations

    def _calculate_confidence(self, chunks: List[Any], answer: str) -> float:
        """Calculate confidence: mean chunk score boosted 1.2x, capped at 1.0,
        rounded to two decimals.

        NOTE(review): the original compared the score against
        ``self.min_confidence`` but both branches returned the identical
        value, so the comparison had no effect; collapsed to a single
        return. Confirm whether low-confidence answers were meant to be
        handled differently.
        """
        if not chunks:
            return 0.0
        avg_score = sum(chunk.score for chunk in chunks) / len(chunks)
        return round(min(avg_score * 1.2, 1.0), 2)
class OpenAIGenerator(GroundedGenerator):
    """OpenAI-specific generator with configurable model parameters."""

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Read model settings from the config.

        Recognized keys (in addition to GroundedGenerator's): ``model``
        (default "gpt-4-turbo-preview"), ``temperature`` (default 0.1),
        ``max_tokens`` (default 1000).
        """
        super().__init__(config)
        # self.config was already normalized (config or {}) by
        # BaseGenerator.__init__, so read from it instead of re-defaulting
        # `(config or {})` on every line as before.
        self.model = self.config.get("model", "gpt-4-turbo-preview")
        self.temperature = self.config.get("temperature", 0.1)
        self.max_tokens = self.config.get("max_tokens", 1000)

    async def _generate_answer(self, query: str, context: str) -> str:
        """Generate an answer using the configured OpenAI chat model.

        Falls back to a plain-text excerpt when the ``openai`` package is
        missing. Unlike the parent class, any other error is logged and
        re-raised rather than returned as answer text.
        """
        prompt = self._create_prompt(query, context)
        try:
            from openai import OpenAI

            client = OpenAI()
            response = client.chat.completions.create(
                model=self.model,
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that answers questions based on the provided context. Always cite your sources when providing information.",
                    },
                    {"role": "user", "content": prompt},
                ],
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )
            return response.choices[0].message.content or "I couldn't generate an answer."
        except ImportError:
            return self._fallback_answer(query, context)
        except Exception as e:
            # Use the module-level logger (not the root logger) and keep the
            # traceback before propagating.
            logger.exception("OpenAI generation error: %s", e)
            raise