Dheeraj-13 committed on
Commit
32378fe
·
1 Parent(s): 5cbdee2

Add Gemini backend support

Browse files
Files changed (3) hide show
  1. apps/web/app.py +1 -1
  2. requirements.txt +1 -0
  3. services/rag/generate.py +27 -4
apps/web/app.py CHANGED
@@ -137,7 +137,7 @@ with gr.Blocks(title="RAG Assistant") as demo:
137
  # Configuration Accordion
138
  with gr.Accordion("Settings", open=True):
139
  backend_radio = gr.Radio(
140
- choices=["openai", "local"],
141
  value="openai",
142
  label="LLM Backend"
143
  )
 
137
  # Configuration Accordion
138
  with gr.Accordion("Settings", open=True):
139
  backend_radio = gr.Radio(
140
+ choices=["openai", "gemini", "local"],
141
  value="openai",
142
  label="LLM Backend"
143
  )
requirements.txt CHANGED
@@ -20,3 +20,4 @@ gradio==4.44.1
20
  pydantic>=2.0,<3.0
21
  sentencepiece
22
  protobuf
 
 
20
  pydantic>=2.0,<3.0
21
  sentencepiece
22
  protobuf
23
+ google-generativeai
services/rag/generate.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  from typing import List, Dict
3
  from openai import OpenAI
 
4
  from ..observability.langfuse_client import observe
5
  import torch
6
 
@@ -79,12 +80,23 @@ class GeneratorService:
79
  def __init__(self):
80
  self.openai_client = None
81
  self.openai_model = "gpt-4o-mini"
 
82
 
83
- api_key = os.getenv("OPENAI_API_KEY")
84
- if api_key:
85
- self.openai_client = OpenAI(api_key=api_key)
 
86
  else:
87
  print("Warning: OPENAI_API_KEY not found. OpenAI backend will not work.")
 
 
 
 
 
 
 
 
 
88
 
89
  @observe(name="generate")
90
  def generate(self, query: str, context_chunks: List[Dict], backend: str = "openai") -> str:
@@ -93,9 +105,20 @@ class GeneratorService:
93
  if backend == "local":
94
  return run_local_generation(query, context_chunks)
95
 
96
- # OpenAI Logic
97
  context = _format_context(context_chunks)
 
98
 
 
 
 
 
 
 
 
 
 
 
 
99
  if self.openai_client is None:
100
  return "Error: OpenAI backend selected but OPENAI_API_KEY not found."
101
 
 
1
  import os
2
  from typing import List, Dict
3
  from openai import OpenAI
4
+ import google.generativeai as genai
5
  from ..observability.langfuse_client import observe
6
  import torch
7
 
 
80
  def __init__(self):
81
  self.openai_client = None
82
  self.openai_model = "gpt-4o-mini"
83
+ self.gemini_configured = False
84
 
85
+ # Initialize OpenAI
86
+ openai_key = os.getenv("OPENAI_API_KEY")
87
+ if openai_key:
88
+ self.openai_client = OpenAI(api_key=openai_key)
89
  else:
90
  print("Warning: OPENAI_API_KEY not found. OpenAI backend will not work.")
91
+
92
+ # Initialize Gemini
93
+ gemini_key = os.getenv("GEMINI_API_KEY")
94
+ if gemini_key:
95
+ genai.configure(api_key=gemini_key)
96
+ self.gemini_model = genai.GenerativeModel("gemini-1.5-flash")
97
+ self.gemini_configured = True
98
+ else:
99
+ print("Warning: GEMINI_API_KEY not found. Gemini backend will not work.")
100
 
101
  @observe(name="generate")
102
  def generate(self, query: str, context_chunks: List[Dict], backend: str = "openai") -> str:
 
105
  if backend == "local":
106
  return run_local_generation(query, context_chunks)
107
 
 
108
  context = _format_context(context_chunks)
109
+ full_input = f"{SYSTEM_PROMPT}\n\nContext:\n{context}\n\nQuestion: {query}"
110
 
111
+ # Dispatch to Gemini
112
+ if backend == "gemini":
113
+ if not self.gemini_configured:
114
+ return "Error: Gemini backend selected but GEMINI_API_KEY not found."
115
+ try:
116
+ response = self.gemini_model.generate_content(full_input)
117
+ return response.text
118
+ except Exception as e:
119
+ return f"Gemini Error: {e}"
120
+
121
+ # OpenAI Logic (Default)
122
  if self.openai_client is None:
123
  return "Error: OpenAI backend selected but OPENAI_API_KEY not found."
124