import os
from typing import List, Dict
from dotenv import dotenv_values
from transformers import pipeline
# --- Load from .env first, then fall back to system environment (for cloud deployment) ---
env_vars = dotenv_values(".env")  # dict parsed from the .env file; empty dict if the file is missing
def get_env(key, default=None):
    """Return the value for *key*, preferring the .env file over the OS environment.

    Falsy values (e.g. an empty string in .env) are treated as "not set" and
    fall through to the next source; *default* is returned when neither source
    yields a truthy value.
    """
    for candidate in (env_vars.get(key), os.environ.get(key)):
        if candidate:
            return candidate
    return default
# --- FIX: Use the correct modern SDK import (google-genai) and initialize client ---
# Module-level defaults: both stay None when the SDK is absent or init fails,
# so LLMReader can detect the situation and fall back to the local model.
genai = None
_gemini_client = None
_api_key = get_env("GEMINI_API_KEY")
try:
    # Rebinds the module-level `genai` placeholder on success.
    from google import genai
    from google.genai import types
    # Initialize client - check both .env and system env for cloud deployment
    if _api_key:
        _gemini_client = genai.Client(api_key=_api_key)
except ImportError:
    # SDK not installed: leave genai/_gemini_client as None (local fallback).
    pass
except Exception as e:
    # SDK present but client construction failed (bad key, config, etc.);
    # warn and continue with _gemini_client = None.
    print(f"Warning: Failed to initialize Gemini client. Check API key/configuration. Error: {e}")
class LLMReader:
    """
    LLM Reader using Google Gemini (via GEMINI_API_KEY from .env or environment).

    Falls back to a small local text-generation model (distilgpt2) when the
    Gemini API key / client is unavailable or an unknown provider is requested.
    """

    def __init__(self, provider: str = "gemini"):
        """Initialize the reader and select the effective provider.

        Args:
            provider: "gemini" (default) or "local". Any other value falls
                back to "local" with a warning.

        Raises:
            ImportError: if provider is "gemini" and an API key is present
                but the google-genai SDK is not installed.
        """
        self.provider = provider.lower()
        # Model name and key come from .env first, then the system
        # environment (so cloud deployments without a .env file still work).
        self.model = get_env("VDOCRAG_LLM_MODEL", "gemini-2.5-flash")
        self.api_key = get_env("GEMINI_API_KEY")
        self.client = _gemini_client  # shared module-level client (may be None)
        self.local_pipeline = None

        print("=" * 50)
        print("LLMReader Init: Loading GEMINI_API_KEY...")
        if self.api_key:
            # Log only a prefix/suffix of the key, never the full secret.
            print(f"LLMReader Init: SUCCESS. Key prefix: {self.api_key[:4]}...{self.api_key[-4:]}")
        else:
            print("LLMReader Init: FAILED. GEMINI_API_KEY not found.")
        print("=" * 50)

        if self.provider == "gemini":
            # Check for API key first - if missing, fall back to local.
            if not self.api_key:
                print("⚠️ No GEMINI_API_KEY found, switching to local model.")
                self.provider = "local"
            elif genai is None:
                # A key is configured but the SDK is missing: that is an
                # operator setup error, so fail loudly instead of degrading.
                raise ImportError("Please install the modern Google GenAI SDK: `pip install google-genai`.")
            elif self.client is None:
                print("⚠️ Failed to initialize Gemini client, switching to local model.")
                self.provider = "local"
        elif self.provider != "local":
            print(f"⚠️ Unknown provider '{self.provider}', defaulting to local.")
            self.provider = "local"

        # FIX: the original duplicated the distilgpt2 loading code in two
        # branches; consolidate into a single guarded load.
        if self.provider == "local" and self.local_pipeline is None:
            self._load_local_pipeline()

    # --------------------------
    # Local pipeline loading
    # --------------------------
    def _load_local_pipeline(self) -> None:
        """Load the small local fallback model (distilgpt2)."""
        print("Loading local model: distilgpt2...")
        self.local_pipeline = pipeline("text-generation", model="distilgpt2")

    # --------------------------
    # Gemini call (modern SDK)
    # --------------------------
    def _call_gemini(self, query: str, context: str) -> str:
        """Answer *query* from *context* via the Gemini API.

        Returns the model's answer text, or a "[Gemini API Error] ..."
        string on failure (never raises).
        """
        system_prompt = (
            "You are a precise data analysis assistant. "
            "Given the provided CONTEXT, answer the user's QUESTION accurately. "
            "If calculations are needed, perform them. "
            "Only respond with the final answer and no additional commentary or explanation."
        )
        user_content = f"CONTEXT:\n---\n{context}\n---\nQUESTION: {query}"
        try:
            config = types.GenerateContentConfig(
                system_instruction=system_prompt,
                temperature=0.1  # low temperature: answers should be near-deterministic
            )
            response = self.client.models.generate_content(
                model=self.model,
                contents=user_content,
                config=config
            )
            # FIX: response.text can be None (e.g. blocked or empty
            # candidates); guard before stripping instead of surfacing an
            # AttributeError through the generic handler below.
            return (response.text or "").strip()
        except Exception as e:
            return f"[Gemini API Error] {type(e).__name__}: {e}"

    # --------------------------
    # Local fallback
    # --------------------------
    def _call_local(self, query: str, context: str) -> str:
        """Answer *query* from *context* with the local distilgpt2 pipeline.

        Returns the generated continuation after the prompt, or an error
        string when the model produced nothing new / echoed the context.
        """
        prompt = (
            f"CONTEXT:\n{context}\n\n"
            f"Based on the context, answer the following question:\n"
            f"QUESTION: {query}\n"
            f"ANSWER:"
        )
        result = self.local_pipeline(
            prompt,
            max_new_tokens=100,
            do_sample=True,
            truncation=True
        )
        generated_text = result[0]["generated_text"]
        # The pipeline echoes the prompt; keep only the continuation.
        answer = generated_text[len(prompt):].strip()
        if not answer or context in answer:
            return "[Local model failed to generate a new answer and may have repeated the context]"
        return answer

    # --------------------------
    # Main answer method
    # --------------------------
    def answer_question(self, query: str, context: str, sources: List[Dict]) -> Dict:
        """Answer *query* against *context* and attach source provenance.

        Args:
            query: the user's question.
            context: retrieved text the answer must be grounded in.
            sources: retrieval hits; each dict is expected to carry "text",
                and optionally "metadata" (with "page") and "score".

        Returns:
            {"text": answer string, "sources": provenance list}.
        """
        if self.provider == "gemini":
            answer_text = self._call_gemini(query, context)
        elif self.provider == "local":
            answer_text = self._call_local(query, context)
        else:
            answer_text = f"[Error: Unknown provider '{self.provider}']"
        # FIX: tolerate sources without a "metadata" entry instead of
        # raising KeyError; "page" then reports as None.
        provenance = [
            {
                "page": (s.get("metadata") or {}).get("page"),
                "text": s["text"][:200],  # cap snippet length for the response payload
                "score": s.get("score", 0),
            }
            for s in sources
        ]
        return {"text": answer_text, "sources": provenance}