import os
from typing import List, Dict
from openai import OpenAI
import google.generativeai as genai
from ..observability.langfuse_client import observe

SYSTEM_PROMPT = """You are a grounded knowledge assistant.
Your goal is to answer the user's question using ONLY the provided context.

Rules:
1. Use the provided context to answer the question.
2. If the answer is not in the context, say "I don't know based on the provided documents."
3. Cite your sources for every fact using the format [doc_id:chunk_id].
4. Do not make up information.
5. Be concise and direct.
"""

# Global variable for lazy loading on the worker node
_local_pipeline = None

def _format_context(chunks: List[Dict]) -> str:
    """Render retrieved chunks as tagged sources the model can cite."""
    context_str = ""
    for c in chunks:
        location = c['metadata']['chunk_id']  # e.g. "doc_id:chunk_id"
        text = c['content']
        context_str += f"<SOURCE ID='{location}'>\n{text}\n</SOURCE>\n\n"
    return context_str
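
# Example: a chunk {"content": "...", "metadata": {"chunk_id": "report.pdf:3"}}
# (hypothetical IDs) renders as:
#   <SOURCE ID='report.pdf:3'>
#   ...
#   </SOURCE>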

def run_local_generation(query: str, context_chunks: List[Dict]) -> str:
    """
    Standalone function to run generation on the GPU node.
    Does NOT depend on GeneratorService instance (avoids pickling OpenAI client).
    """
    global _local_pipeline
    
    # 1. Format Context
    context = _format_context(context_chunks)
    
    # 2. Lazy Load Model (GPU Node)
    if _local_pipeline is None:
        print("Loading local Mistral-7B model (Lazy Load)...")
        try:
            # Heavy deps are imported lazily so non-GPU nodes never need them.
            import torch
            from transformers import pipeline
            model_id = "mistralai/Mistral-7B-Instruct-v0.3"
            _local_pipeline = pipeline(
                "text-generation",
                model=model_id,
                torch_dtype=torch.float16,
                device_map="auto",
            )
        except Exception as e:
            return f"Failed to load local model: {e}"

    # 3. Prepare Messages
    # Mistral-Instruct chat templates have historically rejected a separate
    # "system" role, so the system prompt is folded into the user turn.
    messages = [
        {"role": "user", "content": f"{SYSTEM_PROMPT}\n\nContext:\n{context}\n\nQuestion: {query}"}
    ]
    
    # 4. Run Inference
    try:
        outputs = _local_pipeline(
            messages,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.1,
            top_k=50,
            top_p=0.95,
        )
        # With chat-style input the pipeline returns the full message list;
        # the last entry is the newly generated assistant turn.
        result = outputs[0]['generated_text']
        if isinstance(result, list):
            return result[-1]['content']
        return str(result)
    except Exception as e:
        return f"Generation Error: {e}"


class GeneratorService:
    def __init__(self):
        self.openai_client = None
        self.openai_model = "gpt-4o-mini"
        self.gemini_configured = False
        
        # Initialize OpenAI
        openai_key = os.getenv("OPENAI_API_KEY")
        if openai_key:
            self.openai_client = OpenAI(api_key=openai_key)
        else:
            print("Warning: OPENAI_API_KEY not found. OpenAI backend will not work.")
            
        # Initialize Gemini
        gemini_key = os.getenv("GEMINI_API_KEY")
        if gemini_key:
            genai.configure(api_key=gemini_key)
            self.gemini_model = genai.GenerativeModel("gemini-2.5-flash")
            self.gemini_configured = True
        else:
            print("Warning: GEMINI_API_KEY not found. Gemini backend will not work.")

    @observe(name="generate")
    def generate(self, query: str, context_chunks: List[Dict], backend: str = "openai") -> str:
        """Generate a grounded answer via the chosen backend: "openai",
        "gemini", or "local" (worker-side Mistral)."""
        
        # Dispatch to Local
        if backend == "local":
            return run_local_generation(query, context_chunks)
            
        context = _format_context(context_chunks)
        full_input = f"{SYSTEM_PROMPT}\n\nContext:\n{context}\n\nQuestion: {query}"
        
        # Dispatch to Gemini
        if backend == "gemini":
            if not self.gemini_configured:
                return "Error: Gemini backend selected but GEMINI_API_KEY not found."
            try:
                response = self.gemini_model.generate_content(full_input)
                return response.text
            except Exception as e:
                return f"Gemini Error: {e}"
            
        # OpenAI Logic (Default)
        if self.openai_client is None:
            return "Error: OpenAI backend selected but OPENAI_API_KEY not found."

        try:
            response = self.openai_client.chat.completions.create(
                model=self.openai_model,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}
                ]
            )
            return response.choices[0].message.content
        except Exception as e:
            return f"OpenAI Error: {str(e)}"

# Process-wide singleton so API keys are read and clients built only once.
_shared_generator = None

def get_generator():
    global _shared_generator
    if _shared_generator is None:
        _shared_generator = GeneratorService()
    return _shared_generator
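

# Minimal usage sketch, not part of the service itself. It assumes
# OPENAI_API_KEY is set and uses made-up demo chunks; in the real pipeline
# the retrieval layer supplies context_chunks. Because of the relative
# import above, run it as a module (python -m <package>.<this_module>).
if __name__ == "__main__":
    demo_chunks = [
        {
            "content": "The warranty period is 24 months from purchase.",
            "metadata": {"chunk_id": "warranty.pdf:4"},
        }
    ]
    answer = get_generator().generate(
        "How long is the warranty?", demo_chunks, backend="openai"
    )
    print(answer)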