Spaces:
Sleeping
Sleeping
File size: 4,896 Bytes
26fe9a7 32378fe 26fe9a7 d5cf328 26fe9a7 d5cf328 81b1c13 26fe9a7 81b1c13 32378fe d5cf328 32378fe b11fd72 32378fe ec3da7b 32378fe 26fe9a7 81b1c13 26fe9a7 81b1c13 32378fe 81b1c13 32378fe 81b1c13 d5cf328 81b1c13 26fe9a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import os
from typing import List, Dict
from openai import OpenAI
import google.generativeai as genai
from ..observability.langfuse_client import observe
import torch
# Grounding prompt shared by all three backends (local, Gemini, OpenAI).
# It instructs the model to answer only from the retrieved context and to
# cite sources as [doc_id:chunk_id]. NOTE: this is a runtime string sent to
# the models — editing it changes generation behavior on every backend.
SYSTEM_PROMPT = """You are a grounded knowledge assistant.
Your goal is to answer the user's question using ONLY the provided context.
Rules:
1. Use the provided context to answer the question.
2. If the answer is not in the context, say "I don't know based on the provided documents."
3. Cite your sources for every fact using the format [doc_id:chunk_id].
4. Do not make up information.
5. Be concise and direct.
"""
# Global variable for lazy loading on the worker node
# (holds the transformers text-generation pipeline once run_local_generation
# has loaded it; stays None until the first "local" backend request).
_local_pipeline = None
def _format_context(chunks: List[Dict]) -> str:
context_str = ""
for c in chunks:
location = c['metadata']['chunk_id']
text = c['content']
context_str += f"<SOURCE ID='{location}'>\n{text}\n</SOURCE>\n\n"
return context_str
def run_local_generation(query: str, context_chunks: List[Dict]) -> str:
    """Generate a grounded answer with a local Mistral-7B pipeline.

    Standalone (module-level) on purpose: it is shipped to the GPU worker
    without dragging a GeneratorService instance — and its unpicklable
    OpenAI client — along with it.

    Args:
        query: The user's question.
        context_chunks: Retrieved chunks in the shape _format_context expects.

    Returns:
        The model's reply, or an error string — this function never raises.
    """
    global _local_pipeline

    # Build the grounded context section of the prompt up front.
    context = _format_context(context_chunks)

    # Lazily create the transformers pipeline on first use; later calls on
    # this worker reuse the cached global.
    if _local_pipeline is None:
        print("Loading local Mistral-7B model (Lazy Load)...")
        try:
            from transformers import pipeline
            _local_pipeline = pipeline(
                "text-generation",
                model="mistralai/Mistral-7B-Instruct-v0.3",
                torch_dtype=torch.float16,
                device_map="auto",
            )
        except Exception as e:
            return f"Failed to load local model: {e}"

    chat = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"},
    ]

    try:
        generated = _local_pipeline(
            chat,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.1,
            top_k=50,
            top_p=0.95,
        )
        reply = generated[0]['generated_text']
        # Chat-style pipelines echo the whole message list; the assistant's
        # answer is the final message. Fall back to str() otherwise.
        return reply[-1]['content'] if isinstance(reply, list) else str(reply)
    except Exception as e:
        return f"Generation Error: {e}"
class GeneratorService:
    """Dispatches grounded-answer generation to OpenAI, Gemini, or a local model.

    API keys are read from the environment once, at construction time. A
    backend whose key is missing stays disabled and returns an error string
    rather than raising, so the caller can surface it to the user.
    """

    def __init__(self):
        # OpenAI state: client stays None without a key.
        self.openai_client = None
        self.openai_model = "gpt-4o-mini"
        # Gemini state: initialize gemini_model explicitly so the attribute
        # always exists (previously it was only set when a key was present,
        # leaving an AttributeError hazard).
        self.gemini_model = None
        self.gemini_configured = False

        # Initialize OpenAI
        openai_key = os.getenv("OPENAI_API_KEY")
        if openai_key:
            self.openai_client = OpenAI(api_key=openai_key)
        else:
            print("Warning: OPENAI_API_KEY not found. OpenAI backend will not work.")

        # Initialize Gemini
        gemini_key = os.getenv("GEMINI_API_KEY")
        if gemini_key:
            genai.configure(api_key=gemini_key)
            self.gemini_model = genai.GenerativeModel("gemini-2.5-flash")
            self.gemini_configured = True
        else:
            print("Warning: GEMINI_API_KEY not found. Gemini backend will not work.")

    @observe(name="generate")
    def generate(self, query: str, context_chunks: List[Dict], backend: str = "openai") -> str:
        """Answer `query` using only `context_chunks`, via the chosen backend.

        Args:
            query: The user's question.
            context_chunks: Retrieved chunks (see _format_context for shape).
            backend: "local", "gemini", or "openai" (default). Any other
                value falls through to the OpenAI path — preserved from the
                original behavior so existing callers are unaffected.

        Returns:
            The generated answer, or an error string on failure; never raises.
        """
        # Dispatch to Local: a module-level function so the unpicklable API
        # clients on `self` never travel to the GPU worker.
        if backend == "local":
            return run_local_generation(query, context_chunks)

        context = _format_context(context_chunks)

        # Dispatch to Gemini
        if backend == "gemini":
            if not self.gemini_configured:
                return "Error: Gemini backend selected but GEMINI_API_KEY not found."
            try:
                # Gemini takes one flat prompt rather than role-tagged
                # messages, so the system prompt is inlined here (and only
                # here — the OpenAI path below does not need this string).
                full_input = f"{SYSTEM_PROMPT}\n\nContext:\n{context}\n\nQuestion: {query}"
                response = self.gemini_model.generate_content(full_input)
                return response.text
            except Exception as e:
                return f"Gemini Error: {e}"

        # OpenAI Logic (Default)
        if self.openai_client is None:
            return "Error: OpenAI backend selected but OPENAI_API_KEY not found."
        try:
            response = self.openai_client.chat.completions.create(
                model=self.openai_model,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"},
                ],
            )
            return response.choices[0].message.content
        except Exception as e:
            return f"OpenAI Error: {str(e)}"
# Process-wide cache so the API clients are constructed at most once.
_shared_generator = None


def get_generator():
    """Return the shared GeneratorService, building it lazily on first call."""
    global _shared_generator
    if _shared_generator is not None:
        return _shared_generator
    _shared_generator = GeneratorService()
    return _shared_generator
|