Dheeraj-13's picture
Update Gemini model to gemini-2.5-flash
ec3da7b
import os
from typing import List, Dict
from openai import OpenAI
import google.generativeai as genai
from ..observability.langfuse_client import observe
import torch
# Instruction text shared by all backends: passed as the system message to the
# OpenAI and local chat calls, and prefixed to the prompt sent to Gemini.
SYSTEM_PROMPT = (
    "You are a grounded knowledge assistant.\n"
    "Your goal is to answer the user's question using ONLY the provided context.\n"
    "Rules:\n"
    "1. Use the provided context to answer the question.\n"
    '2. If the answer is not in the context, say "I don\'t know based on the provided documents."\n'
    "3. Cite your sources for every fact using the format [doc_id:chunk_id].\n"
    "4. Do not make up information.\n"
    "5. Be concise and direct.\n"
)

# Module-level cache for the local text-generation pipeline. It starts empty
# and is filled by run_local_generation() the first time that function runs.
_local_pipeline = None
def _format_context(chunks: List[Dict]) -> str:
    """Render retrieved chunks as a sequence of ``<SOURCE>`` blocks.

    Args:
        chunks: Chunk dicts. Each one must provide
            ``chunk["metadata"]["chunk_id"]`` (used as the source ID) and
            ``chunk["content"]`` (the chunk text).

    Returns:
        One ``<SOURCE ID='...'>`` block per chunk, in input order, each
        followed by a blank line. An empty input yields an empty string.

    Raises:
        KeyError: If a chunk is missing ``metadata``, ``chunk_id`` or
            ``content``.
    """
    # Build the result in a single join instead of repeated ``+=`` so the
    # cost stays linear in the total size of the chunks.
    return "".join(
        f"<SOURCE ID='{chunk['metadata']['chunk_id']}'>\n{chunk['content']}\n</SOURCE>\n\n"
        for chunk in chunks
    )
def run_local_generation(query: str, context_chunks: List[Dict]) -> str:
    """Answer ``query`` using the locally loaded Mistral-7B chat pipeline.

    This is a module-level function rather than a ``GeneratorService`` method
    so that it does not depend on a service instance (the original note gives
    the reason as avoiding pickling of the OpenAI client).

    The pipeline is built on the first call and cached in the module-level
    ``_local_pipeline``. Load and inference failures are not raised; they are
    reported as text in the return value.

    Args:
        query: The user's question.
        context_chunks: Retrieved chunks, rendered with ``_format_context``.

    Returns:
        The generated answer, or a message starting with
        ``"Failed to load local model:"`` or ``"Generation Error:"``.
    """
    global _local_pipeline

    # The context is rendered before any model work, so a malformed chunk
    # fails here rather than after a model load.
    rendered_context = _format_context(context_chunks)

    if _local_pipeline is None:
        print("Loading local Mistral-7B model (Lazy Load)...")
        try:
            # Deferred import: transformers is only imported when the
            # pipeline is first built.
            from transformers import pipeline

            _local_pipeline = pipeline(
                "text-generation",
                model="mistralai/Mistral-7B-Instruct-v0.3",
                torch_dtype=torch.float16,
                device_map="auto",
            )
        except Exception as load_error:
            return f"Failed to load local model: {load_error}"

    user_turn = {
        "role": "user",
        "content": f"Context:\n{rendered_context}\n\nQuestion: {query}",
    }
    chat = [{"role": "system", "content": SYSTEM_PROMPT}, user_turn]

    sampling_options = {
        "max_new_tokens": 512,
        "do_sample": True,
        "temperature": 0.1,
        "top_k": 50,
        "top_p": 0.95,
    }

    try:
        generated = _local_pipeline(chat, **sampling_options)[0]["generated_text"]
        if not isinstance(generated, list):
            return str(generated)
        # A list result is treated as a message history; the final entry's
        # content is returned as the reply.
        return generated[-1]["content"]
    except Exception as generation_error:
        return f"Generation Error: {generation_error}"
class GeneratorService:
    """Produce grounded answers via OpenAI, Gemini, or the local model.

    Hosted backends are enabled only when their API key is present in the
    environment at construction time. Errors are returned as message strings
    from ``generate`` rather than raised.
    """

    def __init__(self):
        """Configure the OpenAI and Gemini clients from environment variables.

        Reads ``OPENAI_API_KEY`` and ``GEMINI_API_KEY``. A missing key prints
        a warning and leaves that backend disabled.
        """
        self.openai_client = None
        self.openai_model = "gpt-4o-mini"
        # Always defined (like ``openai_client``) so the attribute exists even
        # when no Gemini key is available.
        self.gemini_model = None
        self.gemini_configured = False

        # Initialize OpenAI
        openai_key = os.getenv("OPENAI_API_KEY")
        if openai_key:
            self.openai_client = OpenAI(api_key=openai_key)
        else:
            print("Warning: OPENAI_API_KEY not found. OpenAI backend will not work.")

        # Initialize Gemini
        gemini_key = os.getenv("GEMINI_API_KEY")
        if gemini_key:
            genai.configure(api_key=gemini_key)
            self.gemini_model = genai.GenerativeModel("gemini-2.5-flash")
            self.gemini_configured = True
        else:
            print("Warning: GEMINI_API_KEY not found. Gemini backend will not work.")

    @observe(name="generate")
    def generate(self, query: str, context_chunks: List[Dict], backend: str = "openai") -> str:
        """Answer ``query`` from ``context_chunks`` using the chosen backend.

        Args:
            query: The user's question.
            context_chunks: Retrieved chunks, rendered with ``_format_context``.
            backend: ``"local"`` delegates to ``run_local_generation``;
                ``"gemini"`` uses the Gemini model; any other value falls
                through to OpenAI.

        Returns:
            The model's answer, or an ``"Error: ..."``, ``"Gemini Error: ..."``
            or ``"OpenAI Error: ..."`` message when the backend is not
            configured or the call fails.
        """
        # Dispatch to Local
        if backend == "local":
            return run_local_generation(query, context_chunks)

        context = _format_context(context_chunks)
        user_prompt = f"Context:\n{context}\n\nQuestion: {query}"

        # Dispatch to Gemini
        if backend == "gemini":
            if not self.gemini_configured:
                return "Error: Gemini backend selected but GEMINI_API_KEY not found."
            # Gemini gets one combined prompt; it is built only on this path
            # because no other backend uses it.
            full_input = f"{SYSTEM_PROMPT}\n\n{user_prompt}"
            try:
                response = self.gemini_model.generate_content(full_input)
                return response.text
            except Exception as e:
                return f"Gemini Error: {e}"

        # OpenAI Logic (Default)
        if self.openai_client is None:
            return "Error: OpenAI backend selected but OPENAI_API_KEY not found."
        try:
            response = self.openai_client.chat.completions.create(
                model=self.openai_model,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
            )
            return response.choices[0].message.content
        except Exception as e:
            return f"OpenAI Error: {e}"
# Holds the single GeneratorService created by get_generator(); None until
# the first call.
_shared_generator = None


def get_generator():
    """Return the shared ``GeneratorService``, constructing it on first use."""
    global _shared_generator
    if _shared_generator is not None:
        return _shared_generator
    _shared_generator = GeneratorService()
    return _shared_generator