state-rag / app.py
fudii0921's picture
Update app.py
e494dd8 verified
import gradio as gr
from groq import Groq
import cohere
import requests
from dotenv import load_dotenv
import os
import json
# Populate os.environ from a local .env file (if one exists) so that the
# GROQ_API_KEY / COHERE_API_KEY values read later via os.environ.get can be
# kept out of the source.
load_dotenv(verbose=True)

# Process-wide cache of the RAG source documents, shared by every session.
# Both names start empty and are filled lazily on the first RAG request
# (see StateParams); an empty list means "not loaded yet".
converted_data = []
# The list handed to Cohere's chat(documents=...). It must be a list: it was
# previously initialised to the builtin `any` function, which is the wrong type.
documents = []
def convert_format(input_list):
    """Convert raw document rows into the shape passed to Cohere's ``documents``.

    Each row must be a mapping with ``'id'``, ``'text'`` and ``'title'`` keys
    (the columns selected by the document query in ``StateParams``). The id is
    coerced to ``str``; text and title are nested under ``"data"``.

    Args:
        input_list: Iterable of row mappings.

    Returns:
        A new list of ``{"id": str, "data": {"text": ..., "title": ...}}``
        dicts, in the same order as the input. Extra row keys are ignored.

    Raises:
        KeyError: If a row lacks ``'id'``, ``'text'`` or ``'title'``.
    """
    return [
        {
            "id": str(row["id"]),
            "data": {"text": row["text"], "title": row["title"]},
        }
        for row in input_list
    ]
class StateParams:
    """Per-session holder for the most recent LLM and RAG answers.

    Instances live in the module-level ``instances`` dict keyed by Gradio
    session hash (presumably rather than in ``gr.State`` because the lock they
    own cannot be deep-copied — see the note on ``self.lock``).
    """

    # System prompt shared by the plain-LLM and RAG calls.
    SYSTEM_PROMPT = "You are a helpful assistant, answer questions concisely."
    # JSON endpoint returning the RAG rows (id, title, snippet-as-text) for
    # cohere_documents id 8.
    DOCUMENTS_URL = (
        "https://www.ryhintl.com/dbjson/getjson?sqlcmd=select `id` as `id`,"
        "`title`,`snippet` as `text` from cohere_documents where id = '8'"
    )
    # Seconds to wait for the document endpoint before giving up; without a
    # timeout a stalled server would block the handler indefinitely.
    REQUEST_TIMEOUT = 30

    def __init__(self):
        from threading import Lock
        self.llm_answer_text = ""
        self.rag_answer_text = ""
        # NOTE(review): these two per-instance fields are not read anywhere in
        # this file; the document cache actually used is the module-level
        # ``converted_data`` / ``documents`` pair.
        self.converted_data = []
        self.documents = []
        # Serialises the API calls made through this instance.
        # Lock objects cannot be deepcopied.
        self.lock = Lock()

    def get_llm_answer(self, prompt):
        """Answer *prompt* with the Groq-hosted LLM, without retrieval context.

        Reads the API key from the ``GROQ_API_KEY`` environment variable.
        The answer is stored in ``self.llm_answer_text`` and returned.
        """
        with self.lock:
            client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
            chat_history = [
                {"role": "system", "content": self.SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ]
            response = client.chat.completions.create(
                model="llama-3.3-70b-versatile",
                messages=chat_history,
                max_tokens=1024,
                temperature=0,
            )
            self.llm_answer_text = response.choices[0].message.content
            return self.llm_answer_text

    @classmethod
    def _load_documents(cls):
        """Fetch the RAG documents once and cache them in the module globals.

        Does nothing when the cache is already populated. Raises
        ``requests.HTTPError`` on a non-2xx response and ``requests.Timeout``
        if the endpoint does not answer within ``REQUEST_TIMEOUT`` seconds.
        """
        global converted_data
        global documents
        if converted_data:
            return
        response = requests.get(cls.DOCUMENTS_URL, timeout=cls.REQUEST_TIMEOUT)
        # Fail loudly on HTTP errors instead of parsing an error page as rows.
        response.raise_for_status()
        converted_data = convert_format(json.loads(response.content))
        documents = converted_data

    def get_rag_answer(self, prompt):
        """Answer *prompt* with Cohere, grounded in the cached documents.

        Loads the shared document cache on first use, then reads the API key
        from the ``COHERE_API_KEY`` environment variable. The answer text is
        stored in ``self.rag_answer_text`` and returned.
        """
        # The cache is module-wide, so it is filled outside this per-session
        # lock; concurrent first requests from different sessions may each
        # fetch it.
        self._load_documents()
        with self.lock:
            co = cohere.ClientV2(api_key=os.environ.get("COHERE_API_KEY"))
            messages = [
                {"role": "system", "content": self.SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ]
            response = co.chat(
                model="command-r-plus-08-2024",
                documents=documents,
                messages=messages,
            )
            self.rag_answer_text = response.message.content[0].text
            return self.rag_answer_text
# Registry mapping each Gradio session hash to that session's StateParams.
instances = dict()
def initialize_instance(request: gr.Request):
    """Attach a brand-new StateParams to the requesting browser session.

    Any state previously registered under the same session hash is replaced.
    Returns a (Japanese) status message for the status textbox.
    """
    fresh_state = StateParams()
    instances[request.session_hash] = fresh_state
    return "セッションが初期化されました。"
def cleanup_instance(request: gr.Request):
    """Forget the StateParams belonging to the closing browser session.

    Safe to call for a session that was never initialised or has already
    been cleaned up; returns None.
    """
    # pop with a default is a single dict operation: unlike a membership test
    # followed by `del`, it cannot raise KeyError if the entry disappears
    # between the two steps.
    instances.pop(request.session_hash, None)
def llm_content(request: gr.Request, prompt: str):
    """Answer *prompt* with the plain LLM using the caller's session state.

    Returns the answer text, or an error string (Japanese: "session not
    initialised") when no StateParams is registered for this session.
    """
    # One lookup instead of `in` + indexing, so a concurrent cleanup cannot
    # remove the entry between the check and the access.
    instance = instances.get(request.session_hash)
    if instance is None:
        return "Error: セッションが初期化されていません。"
    return instance.get_llm_answer(prompt)
def rag_content(request: gr.Request, prompt: str):
    """Answer *prompt* with the RAG pipeline using the caller's session state.

    Returns the answer text, or an error string (Japanese: "session not
    initialised") when no StateParams is registered for this session.
    """
    # One lookup instead of `in` + indexing, so a concurrent cleanup cannot
    # remove the entry between the check and the access.
    instance = instances.get(request.session_hash)
    if instance is None:
        return "Error: セッションが初期化されていません。"
    return instance.get_rag_answer(prompt)
with gr.Blocks(title="ステート") as ryhrag:
    # Shows the status message returned by initialize_instance.
    output = gr.Textbox(label="ステート")
    # Fixed set of sample questions sent to both the LLM and the RAG pipeline.
    prompt = gr.Dropdown(
        [
            "YHプロジェクトの責任者は誰ですか?",
            "ER-RAGのアーキテクチャについて教えてください。",
            "YHプロジェクトのコストはいくらですか?",
        ],
        label="プロンプト",
        info="質問を選択してください。",
    )
    llm_output = gr.Textbox(label="LLM")
    rag_output = gr.Textbox(label="RAG")
    llm_btn = gr.Button("LLM")
    llm_btn.click(llm_content, inputs=prompt, outputs=llm_output)
    rag_btn = gr.Button("RAG")
    rag_btn.click(rag_content, inputs=prompt, outputs=rag_output)
    # Create this browser session's StateParams when the page loads.
    ryhrag.load(initialize_instance, inputs=None, outputs=output)
    # Drop the session's StateParams when the tab is closed or refreshed.
    # Blocks.close() shuts down a running server and does not register a
    # listener, so cleanup never ran and `instances` grew without bound;
    # Blocks.unload() is the event listener for session end.
    ryhrag.unload(cleanup_instance)

ryhrag.launch()