File size: 5,499 Bytes
4e3cee0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import os
from typing import List, Dict
from dotenv import dotenv_values
from transformers import pipeline

# --- Load from .env first, then fall back to system environment (for cloud deployment) ---
env_vars = dotenv_values(".env")  # dict of key/value pairs parsed from the .env file

def get_env(key, default=None):
    """Look up *key* in .env, then the process environment, then *default*.

    Falsy values (e.g. an empty string in .env) are treated as "unset" so the
    lookup falls through to the next source.
    """
    value = env_vars.get(key)
    if not value:
        value = os.environ.get(key)
    return value if value else default

# --- FIX: Use the correct modern SDK import (google-genai) and initialize client ---
# Module-level sentinel values: `genai` stays None when the SDK is not
# installed, `_gemini_client` stays None when no API key is available or
# client construction fails. LLMReader.__init__ inspects both to decide
# whether to fall back to the local model.
genai = None
_gemini_client = None
_api_key = get_env("GEMINI_API_KEY")
try:
    # Rebinds the module-level `genai` sentinel on success; also brings in
    # `types`, which _call_gemini uses to build GenerateContentConfig.
    from google import genai
    from google.genai import types
    
    # Initialize client - check both .env and system env for cloud deployment
    if _api_key:
        _gemini_client = genai.Client(api_key=_api_key)
except ImportError:
    # SDK not installed: leave genai/_gemini_client as None; LLMReader
    # falls back to the local pipeline (or raises with install guidance).
    pass
except Exception as e:
    # Client construction failed (bad key, network/config issue): warn but
    # keep importing the module so the local fallback remains usable.
    print(f"Warning: Failed to initialize Gemini client. Check API key/configuration. Error: {e}")

class LLMReader:
    """
    LLM Reader using Google Gemini (via GEMINI_API_KEY from .env or environment)
    Falls back to a local small model if unavailable.

    Provider resolution:
      * "gemini" -- requires the google-genai SDK, a GEMINI_API_KEY, and a
        successfully initialized module-level client; otherwise falls back.
      * "local"  -- distilgpt2 text-generation pipeline (also the fallback).
      * anything else -- warned about and treated as "local".
    """

    def __init__(self, provider: str = "gemini"):
        """
        Args:
            provider: "gemini" (default) or "local"; case-insensitive.
                Unknown values fall back to "local" with a warning.

        Raises:
            ImportError: provider is "gemini", an API key is present, but the
                google-genai SDK is not installed.
        """
        self.provider = provider.lower()

        # Load from .env or system environment (for cloud deployment)
        self.model = get_env("VDOCRAG_LLM_MODEL", "gemini-2.5-flash")
        self.api_key = get_env("GEMINI_API_KEY")
        self.client = _gemini_client
        self.local_pipeline = None

        print("=" * 50)
        print("LLMReader Init: Loading GEMINI_API_KEY...")
        if self.api_key:
            # Never log the full key: show only a short prefix/suffix.
            print(f"LLMReader Init: SUCCESS. Key prefix: {self.api_key[:4]}...{self.api_key[-4:]}")
        else:
            print("LLMReader Init: FAILED. GEMINI_API_KEY not found.")
        print("=" * 50)

        # FIX: normalize unknown providers up front so the local-pipeline
        # loading below is shared instead of duplicated in two branches.
        if self.provider not in ("gemini", "local"):
            print(f"⚠️ Unknown provider '{self.provider}', defaulting to local.")
            self.provider = "local"

        if self.provider == "gemini":
            # Check for API key first - if missing, fall back to local
            if not self.api_key:
                print("⚠️ No GEMINI_API_KEY found, switching to local model.")
                self.provider = "local"
            elif genai is None:
                raise ImportError("Please install the modern Google GenAI SDK: `pip install google-genai`.")
            elif self.client is None:
                print("⚠️ Failed to initialize Gemini client, switching to local model.")
                self.provider = "local"

        if self.provider == "local":
            self._load_local_pipeline()

    def _load_local_pipeline(self):
        """Load the local distilgpt2 text-generation pipeline (idempotent)."""
        if self.local_pipeline is None:
            print("Loading local model: distilgpt2...")
            self.local_pipeline = pipeline("text-generation", model="distilgpt2")

    # --------------------------
    # Gemini call (modern SDK)
    # --------------------------
    def _call_gemini(self, query: str, context: str) -> str:
        """Answer *query* against *context* via the Gemini API.

        Returns the model's answer text, or a bracketed "[Gemini API Error]"
        string on any failure (never raises).
        """
        system_prompt = (
            "You are a precise data analysis assistant. "
            "Given the provided CONTEXT, answer the user's QUESTION accurately. "
            "If calculations are needed, perform them. "
            "Only respond with the final answer and no additional commentary or explanation."
        )

        user_content = f"CONTEXT:\n---\n{context}\n---\nQUESTION: {query}"

        try:
            config = types.GenerateContentConfig(
                system_instruction=system_prompt,
                temperature=0.1,  # low temperature: favor precise, stable answers
            )
            response = self.client.models.generate_content(
                model=self.model,
                contents=user_content,
                config=config,
            )
            # FIX: response.text can be None (e.g. safety-blocked or empty
            # candidates); guard before calling .strip().
            return (response.text or "").strip()
        except Exception as e:
            # Surface API failures to the caller as text rather than raising.
            return f"[Gemini API Error] {type(e).__name__}: {e}"

    # --------------------------
    # Local fallback
    # --------------------------
    def _call_local(self, query: str, context: str) -> str:
        """Answer *query* with the local distilgpt2 pipeline.

        Returns only the text generated after the prompt; if nothing new was
        produced (or the context was merely echoed), returns a bracketed
        failure message instead.
        """
        prompt = (
            f"CONTEXT:\n{context}\n\n"
            f"Based on the context, answer the following question:\n"
            f"QUESTION: {query}\n"
            f"ANSWER:"
        )

        result = self.local_pipeline(
            prompt,
            max_new_tokens=100,
            do_sample=True,
            truncation=True,
        )
        generated_text = result[0]["generated_text"]
        # The pipeline echoes the prompt; keep only the continuation.
        answer = generated_text[len(prompt):].strip()

        if not answer or context in answer:
            return "[Local model failed to generate a new answer and may have repeated the context]"
        return answer

    # --------------------------
    # Main answer method
    # --------------------------
    def answer_question(self, query: str, context: str, sources: List[Dict]) -> Dict:
        """Answer *query* using *context* and attach provenance for *sources*.

        Args:
            query: The user's question.
            context: Retrieved text the answer should be grounded in.
            sources: Retrieval hits; each ideally has "metadata" (with "page"),
                "text", and "score" keys.

        Returns:
            Dict with "text" (the answer) and "sources" (per-hit provenance:
            page, first 200 chars of text, score).
        """
        if self.provider == "gemini":
            answer_text = self._call_gemini(query, context)
        elif self.provider == "local":
            answer_text = self._call_local(query, context)
        else:
            answer_text = f"[Error: Unknown provider '{self.provider}']"

        # FIX: be defensive about partially-populated source dicts so one
        # malformed retrieval hit (missing "metadata"/"text") cannot crash
        # the whole answer with a KeyError.
        provenance = [
            {
                "page": s.get("metadata", {}).get("page"),
                "text": s.get("text", "")[:200],
                "score": s.get("score", 0),
            }
            for s in sources
        ]

        return {"text": answer_text, "sources": provenance}