Spaces:
Sleeping
Sleeping
| """ | |
| Gemini File Search Client | |
| Handles interaction with Google Gemini File Search API for RAG. | |
| """ | |
| import os | |
| import logging | |
| from typing import List, Dict, Any, Optional | |
| from dataclasses import dataclass | |
| logger = logging.getLogger(__name__) | |
# Optional dependency guard: the google-genai SDK may not be installed.
# GEMINI_AVAILABLE gates client construction; GeminiFileSearchClient raises
# ImportError with install instructions when it is False.
try:
    from google import genai
    from google.genai import types
    GEMINI_AVAILABLE = True
except ImportError:
    # SDK missing — defer the failure until a client is actually created.
    GEMINI_AVAILABLE = False
@dataclass
class GeminiFileSearchResult:
    """Result from a Gemini File Search query.

    Attributes:
        answer: Model-generated answer text.
        sources: Document references extracted from grounding metadata
            (dicts with "content", "filename" and "score" keys).
        grounding_metadata: Raw grounding metadata from the API response, if any.
        query: The original user query that produced this result.
    """
    # NOTE: the @dataclass decorator was missing; without it the class has no
    # generated __init__, so GeminiFileSearchResult(answer=..., ...) raised
    # TypeError everywhere search() built a result.
    answer: str
    sources: List[Dict[str, Any]]  # list of document references
    grounding_metadata: Optional[Dict[str, Any]] = None
    query: str = ""
class GeminiFileSearchClient:
    """Client for interacting with the Gemini File Search API (RAG)."""

    def __init__(self, api_key: Optional[str] = None, store_name: Optional[str] = None):
        """
        Initialize Gemini File Search client.

        Args:
            api_key: Gemini API key (defaults to GEMINI_API_KEY env var)
            store_name: File search store name (defaults to GEMINI_FILESTORE_NAME env var)

        Raises:
            ImportError: If the google-genai package is not installed.
            ValueError: If the API key or store name cannot be resolved.
        """
        if not GEMINI_AVAILABLE:
            raise ImportError("google-genai package not installed. Install with: pip install google-genai")

        self.api_key = api_key or os.getenv("GEMINI_API_KEY")
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY not found. Set it in .env file or pass as argument.")

        self.store_name = store_name or os.getenv("GEMINI_FILESTORE_NAME")
        if not self.store_name:
            raise ValueError("GEMINI_FILESTORE_NAME not found. Set it in .env file or pass as argument.")

        self.client = genai.Client(api_key=self.api_key)
        self.model = "gemini-2.5-flash"  # or "gemini-2.5-pro"

    @staticmethod
    def _build_filter_context(filters: Optional[Dict[str, Any]]) -> str:
        """Render soft filter hints as extra query text.

        Gemini File Search does not support explicit filters in the API, so
        filters are appended to the prompt as guidance. Each value may be a
        scalar or a list; values are coerced to str before joining (the
        previous code crashed with TypeError on integer years).
        """
        if not filters:
            return ""
        filter_parts = []
        # (filter key, label shown in the prompt) — order matters for the prompt text.
        for key, label in (("year", "Year"), ("sources", "Source"),
                           ("district", "District"), ("filenames", "Filename")):
            value = filters.get(key)
            if value:
                values = value if isinstance(value, list) else [value]
                filter_parts.append(f"{label}: {', '.join(str(v) for v in values)}")
        if not filter_parts:
            return ""
        return f"\n\nPlease focus on documents matching these criteria: {', '.join(filter_parts)}"

    def _call_api(self, model: str, full_query: str) -> Any:
        """Call generate_content with the file-search tool.

        Tries the typed-config format documented at
        https://ai.google.dev/gemini-api/docs/file-search first, then falls
        back to a plain-dict tool spec for SDK versions that lack the types.

        Raises:
            Exception: If both API formats fail (chained to the root cause).
        """
        try:
            return self.client.models.generate_content(
                model=model,
                contents=full_query,
                config=types.GenerateContentConfig(
                    tools=[
                        types.Tool(
                            file_search=types.FileSearch(
                                file_search_store_names=[self.store_name]
                            )
                        )
                    ]
                ),
            )
        except (AttributeError, TypeError) as e:
            # Fallback: try alternative (dict-based) tool format.
            logger.warning(f"Primary API format failed, trying alternative: {e}")
            try:
                return self.client.models.generate_content(
                    model=model,
                    contents=full_query,
                    tools=[{
                        "file_search": {
                            "file_search_store_names": [self.store_name]
                        }
                    }],
                )
            except Exception as e2:
                # Chain the cause so debugging shows the underlying SDK error.
                raise Exception(f"Failed to call Gemini API: {e2}") from e2

    @staticmethod
    def _extract_answer(response: Any) -> str:
        """Pull the answer text out of a response, tolerating several shapes.

        Order of preference: `response.text`, then the text parts of the first
        candidate, then `str(response)` as a last resort. Returns "" when a
        candidate exists but carries no extractable text.
        """
        if hasattr(response, 'text'):
            return response.text
        if hasattr(response, 'candidates') and response.candidates:
            candidate = response.candidates[0]
            if hasattr(candidate, 'content') and candidate.content:
                if hasattr(candidate.content, 'parts'):
                    return " ".join(
                        part.text for part in candidate.content.parts
                        if hasattr(part, 'text')
                    )
                if isinstance(candidate.content, str):
                    return candidate.content
            return ""
        return str(response)

    @staticmethod
    def _extract_sources(response: Any) -> tuple:
        """Extract (sources, grounding_metadata) from the first candidate.

        Handles both typed-object and dict grounding formats. Chunks that
        cannot be parsed are logged and skipped; chunks with neither text nor
        a file name are dropped.

        Returns:
            (sources, grounding_metadata) where sources is a list of dicts
            with "content", "filename" and "score" keys.
        """
        sources: List[Dict[str, Any]] = []
        grounding_metadata = None

        if not (hasattr(response, 'candidates') and response.candidates):
            return sources, grounding_metadata
        candidate = response.candidates[0]
        if not hasattr(candidate, 'grounding_metadata'):
            return sources, grounding_metadata
        grounding_metadata = candidate.grounding_metadata

        # The metadata may be a typed object or a plain dict.
        grounding_chunks = None
        if hasattr(grounding_metadata, 'grounding_chunks'):
            grounding_chunks = grounding_metadata.grounding_chunks
        elif isinstance(grounding_metadata, dict) and 'grounding_chunks' in grounding_metadata:
            grounding_chunks = grounding_metadata['grounding_chunks']

        for chunk in (grounding_chunks or []):
            try:
                if isinstance(chunk, dict):
                    chunk_data = chunk
                else:
                    # Object format — normalize to dict-like access.
                    chunk_data = {}
                    if hasattr(chunk, 'chunk'):
                        chunk_obj = chunk.chunk
                        chunk_data['chunk'] = {
                            'text': getattr(chunk_obj, 'text', ''),
                            'file_name': getattr(chunk_obj, 'file_name', ''),
                        }
                    if hasattr(chunk, 'relevance_score'):
                        chunk_data['relevance_score'] = {
                            'score': getattr(chunk.relevance_score, 'score', 0.0)
                        }

                chunk_info = chunk_data.get('chunk', {})
                text = chunk_info.get('text', '') if isinstance(chunk_info, dict) else ''
                file_name = chunk_info.get('file_name', '') if isinstance(chunk_info, dict) else ''
                score_data = chunk_data.get('relevance_score', {})
                score = score_data.get('score', 0.0) if isinstance(score_data, dict) else 0.0

                if text or file_name:  # only add if we have content
                    sources.append({
                        "content": text,
                        "filename": file_name,
                        "score": score,
                    })
            except Exception as e:
                logger.warning(f"Error extracting chunk info: {e}")
                continue

        return sources, grounding_metadata

    def search(
        self,
        query: str,
        filters: Optional[Dict[str, Any]] = None,
        model: Optional[str] = None
    ) -> "GeminiFileSearchResult":
        """
        Search using Gemini File Search.

        Args:
            query: User query
            filters: Optional filters (year, source, district, etc.)
            model: Model to use (defaults to gemini-2.5-flash)

        Returns:
            GeminiFileSearchResult with answer and sources. On any failure the
            error is folded into the answer text rather than raised.
        """
        model = model or self.model

        # Explicit instruction to ground the answer in retrieved documents only.
        instruction = "\n\nIMPORTANT: Only use information from the retrieved documents. Do not use information from your training data unless it's explicitly mentioned in the retrieved documents. If the retrieved documents don't contain the requested information, clearly state that.\n\n"
        full_query = query + self._build_filter_context(filters) + instruction

        try:
            response = self._call_api(model, full_query)
            answer = self._extract_answer(response)
            sources, grounding_metadata = self._extract_sources(response)
            return GeminiFileSearchResult(
                answer=answer,
                sources=sources,
                grounding_metadata=grounding_metadata,
                query=query,
            )
        except Exception as e:
            # Return an error result instead of propagating to the caller.
            return GeminiFileSearchResult(
                answer=f"I apologize, but I encountered an error: {str(e)}",
                sources=[],
                query=query,
            )

    def format_sources_for_display(self, result: "GeminiFileSearchResult") -> List[Any]:
        """
        Format Gemini sources to match the format expected by the UI.

        Args:
            result: A search result whose `sources` entries are dicts with
                "content", "filename" and "score" keys.

        Returns:
            List of langchain Document objects compatible with existing
            display code.
        """
        # Imported lazily so this module loads even without langchain installed.
        from langchain.docstore.document import Document

        formatted_sources = []
        for i, source in enumerate(result.sources):
            # Create a Document object compatible with existing code.
            doc = Document(
                page_content=source.get("content", ""),
                metadata={
                    "filename": source.get("filename", "Unknown"),
                    "source": "Gemini File Search",
                    "score": source.get("score"),
                    "chunk_index": i,
                    # Default fields that downstream display code may expect.
                    "page": None,
                    "year": None,
                    "district": None,
                },
            )
            formatted_sources.append(doc)
        return formatted_sources