import os
from typing import List, Dict
from dotenv import dotenv_values
from transformers import pipeline

env_vars = dotenv_values(".env")

def get_env(key, default=None):
    """Get an env var from .env first, then from the system environment."""
    return env_vars.get(key) or os.environ.get(key) or default

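
# Optional Gemini SDK: imported lazily so the module still loads when the
# google-genai package is missing or the client cannot be created.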
genai = None
_gemini_client = None
_api_key = get_env("GEMINI_API_KEY")
try:
    from google import genai
    from google.genai import types

    if _api_key:
        _gemini_client = genai.Client(api_key=_api_key)
except ImportError:
    pass
except Exception as e:
    print(f"Warning: Failed to initialize Gemini client. Check API key/configuration. Error: {e}")

class LLMReader:
    """
    LLM Reader using Google Gemini (via GEMINI_API_KEY from .env or the environment).
    Falls back to a small local model if Gemini is unavailable.
    """

    def __init__(self, provider: str = "gemini"):
        self.provider = provider.lower()

        self.model = get_env("VDOCRAG_LLM_MODEL", "gemini-2.5-flash")
        self.api_key = get_env("GEMINI_API_KEY")
        self.client = _gemini_client
        self.local_pipeline = None

        print("=" * 50)
        print("LLMReader Init: Loading GEMINI_API_KEY...")
        if self.api_key:
            print(f"LLMReader Init: SUCCESS. Key prefix: {self.api_key[:4]}...{self.api_key[-4:]}")
        else:
            print("LLMReader Init: FAILED. GEMINI_API_KEY not found.")
        print("=" * 50)

        # Resolve the provider, falling back to the local model when Gemini is unusable.
        if self.provider == "gemini":
            if not self.api_key:
                print("⚠️ No GEMINI_API_KEY found, switching to local model.")
                self.provider = "local"
            elif genai is None:
                raise ImportError("Please install the modern Google GenAI SDK: `pip install google-genai`.")
            elif self.client is None:
                print("⚠️ Failed to initialize Gemini client, switching to local model.")
                self.provider = "local"

        if self.provider not in ("gemini", "local"):
            print(f"⚠️ Unknown provider '{self.provider}', defaulting to local.")
            self.provider = "local"

        if self.provider == "local":
            print("Loading local model: distilgpt2...")
            self.local_pipeline = pipeline("text-generation", model="distilgpt2")

    def _call_gemini(self, query: str, context: str) -> str:
        """Ask Gemini to answer the query from the given context; returns an error string on failure."""
        system_prompt = (
            "You are a precise data analysis assistant. "
            "Given the provided CONTEXT, answer the user's QUESTION accurately. "
            "If calculations are needed, perform them. "
            "Only respond with the final answer and no additional commentary or explanation."
        )

        user_content = f"CONTEXT:\n---\n{context}\n---\nQUESTION: {query}"

        try:
            config = types.GenerateContentConfig(
                system_instruction=system_prompt,
                temperature=0.1,
            )
            response = self.client.models.generate_content(
                model=self.model,
                contents=user_content,
                config=config,
            )
            return response.text.strip()
        except Exception as e:
            return f"[Gemini API Error] {type(e).__name__}: {e}"

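    # Local fallback path: distilgpt2 is a very small model, so answers here are
    # best-effort and mainly useful for running without an API key.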
    def _call_local(self, query: str, context: str) -> str:
        prompt = (
            f"CONTEXT:\n{context}\n\n"
            "Based on the context, answer the following question:\n"
            f"QUESTION: {query}\n"
            "ANSWER:"
        )

        result = self.local_pipeline(
            prompt,
            max_new_tokens=100,
            do_sample=True,
            truncation=True,
        )
        generated_text = result[0]["generated_text"]
        # The pipeline echoes the prompt, so strip it to keep only the newly generated text.
        answer = generated_text[len(prompt):].strip()

        if not answer or context in answer:
            return "[Local model failed to generate a new answer and may have repeated the context]"
        return answer

    def answer_question(self, query: str, context: str, sources: List[Dict]) -> Dict:
        if self.provider == "gemini":
            answer_text = self._call_gemini(query, context)
        elif self.provider == "local":
            answer_text = self._call_local(query, context)
        else:
            answer_text = f"[Error: Unknown provider '{self.provider}']"

        # Provenance for each retrieved source chunk: page number, a 200-character
        # snippet, and the retrieval score.
        provenance = [
            {
                "page": s["metadata"].get("page"),
                "text": s["text"][:200],
                "score": s.get("score", 0),
            }
            for s in sources
        ]

        return {"text": answer_text, "sources": provenance}
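

# Usage sketch with hypothetical values: shows the `sources` shape that
# answer_question expects, a list of dicts each carrying a "metadata" dict with a
# "page" key, the chunk "text", and an optional retrieval "score". The query,
# context, and page number below are made up for illustration.
if __name__ == "__main__":
    reader = LLMReader(provider="gemini")  # falls back to distilgpt2 without a GEMINI_API_KEY
    demo_sources = [
        {
            "metadata": {"page": 3},
            "text": "Q1 revenue was $1.2M and Q2 revenue was $1.5M.",
            "score": 0.87,
        }
    ]
    demo_context = "\n".join(s["text"] for s in demo_sources)
    result = reader.answer_question(
        query="What was the combined revenue for Q1 and Q2?",
        context=demo_context,
        sources=demo_sources,
    )
    print(result["text"])
    print(result["sources"])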