Spaces:
Sleeping
Sleeping
Commit
·
32378fe
1
Parent(s):
5cbdee2
Add Gemini backend support
Browse files- apps/web/app.py +1 -1
- requirements.txt +1 -0
- services/rag/generate.py +27 -4
apps/web/app.py
CHANGED
|
@@ -137,7 +137,7 @@ with gr.Blocks(title="RAG Assistant") as demo:
|
|
| 137 |
# Configuration Accordion
|
| 138 |
with gr.Accordion("Settings", open=True):
|
| 139 |
backend_radio = gr.Radio(
|
| 140 |
-
choices=["openai", "local"],
|
| 141 |
value="openai",
|
| 142 |
label="LLM Backend"
|
| 143 |
)
|
|
|
|
| 137 |
# Configuration Accordion
|
| 138 |
with gr.Accordion("Settings", open=True):
|
| 139 |
backend_radio = gr.Radio(
|
| 140 |
+
choices=["openai", "gemini", "local"],
|
| 141 |
value="openai",
|
| 142 |
label="LLM Backend"
|
| 143 |
)
|
requirements.txt
CHANGED
|
@@ -20,3 +20,4 @@ gradio==4.44.1
|
|
| 20 |
pydantic>=2.0,<3.0
|
| 21 |
sentencepiece
|
| 22 |
protobuf
|
|
|
|
|
|
| 20 |
pydantic>=2.0,<3.0
|
| 21 |
sentencepiece
|
| 22 |
protobuf
|
| 23 |
+
google-generativeai
|
services/rag/generate.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import os
|
| 2 |
from typing import List, Dict
|
| 3 |
from openai import OpenAI
|
|
|
|
| 4 |
from ..observability.langfuse_client import observe
|
| 5 |
import torch
|
| 6 |
|
|
@@ -79,12 +80,23 @@ class GeneratorService:
|
|
| 79 |
def __init__(self):
|
| 80 |
self.openai_client = None
|
| 81 |
self.openai_model = "gpt-4o-mini"
|
|
|
|
| 82 |
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
|
|
|
| 86 |
else:
|
| 87 |
print("Warning: OPENAI_API_KEY not found. OpenAI backend will not work.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
@observe(name="generate")
|
| 90 |
def generate(self, query: str, context_chunks: List[Dict], backend: str = "openai") -> str:
|
|
@@ -93,9 +105,20 @@ class GeneratorService:
|
|
| 93 |
if backend == "local":
|
| 94 |
return run_local_generation(query, context_chunks)
|
| 95 |
|
| 96 |
-
# OpenAI Logic
|
| 97 |
context = _format_context(context_chunks)
|
|
|
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
if self.openai_client is None:
|
| 100 |
return "Error: OpenAI backend selected but OPENAI_API_KEY not found."
|
| 101 |
|
|
|
|
| 1 |
import os
|
| 2 |
from typing import List, Dict
|
| 3 |
from openai import OpenAI
|
| 4 |
+
import google.generativeai as genai
|
| 5 |
from ..observability.langfuse_client import observe
|
| 6 |
import torch
|
| 7 |
|
|
|
|
| 80 |
def __init__(self):
|
| 81 |
self.openai_client = None
|
| 82 |
self.openai_model = "gpt-4o-mini"
|
| 83 |
+
self.gemini_configured = False
|
| 84 |
|
| 85 |
+
# Initialize OpenAI
|
| 86 |
+
openai_key = os.getenv("OPENAI_API_KEY")
|
| 87 |
+
if openai_key:
|
| 88 |
+
self.openai_client = OpenAI(api_key=openai_key)
|
| 89 |
else:
|
| 90 |
print("Warning: OPENAI_API_KEY not found. OpenAI backend will not work.")
|
| 91 |
+
|
| 92 |
+
# Initialize Gemini
|
| 93 |
+
gemini_key = os.getenv("GEMINI_API_KEY")
|
| 94 |
+
if gemini_key:
|
| 95 |
+
genai.configure(api_key=gemini_key)
|
| 96 |
+
self.gemini_model = genai.GenerativeModel("gemini-1.5-flash")
|
| 97 |
+
self.gemini_configured = True
|
| 98 |
+
else:
|
| 99 |
+
print("Warning: GEMINI_API_KEY not found. Gemini backend will not work.")
|
| 100 |
|
| 101 |
@observe(name="generate")
|
| 102 |
def generate(self, query: str, context_chunks: List[Dict], backend: str = "openai") -> str:
|
|
|
|
| 105 |
if backend == "local":
|
| 106 |
return run_local_generation(query, context_chunks)
|
| 107 |
|
|
|
|
| 108 |
context = _format_context(context_chunks)
|
| 109 |
+
full_input = f"{SYSTEM_PROMPT}\n\nContext:\n{context}\n\nQuestion: {query}"
|
| 110 |
|
| 111 |
+
# Dispatch to Gemini
|
| 112 |
+
if backend == "gemini":
|
| 113 |
+
if not self.gemini_configured:
|
| 114 |
+
return "Error: Gemini backend selected but GEMINI_API_KEY not found."
|
| 115 |
+
try:
|
| 116 |
+
response = self.gemini_model.generate_content(full_input)
|
| 117 |
+
return response.text
|
| 118 |
+
except Exception as e:
|
| 119 |
+
return f"Gemini Error: {e}"
|
| 120 |
+
|
| 121 |
+
# OpenAI Logic (Default)
|
| 122 |
if self.openai_client is None:
|
| 123 |
return "Error: OpenAI backend selected but OPENAI_API_KEY not found."
|
| 124 |
|