import gradio as gr
from groq import Groq
import cohere
import requests
from dotenv import load_dotenv
import os
import json

load_dotenv(verbose=True)
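# GROQ_API_KEY and COHERE_API_KEY are read from the local .env file
# (requires the gradio, groq, cohere, requests and python-dotenv packages).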

converted_data = []
documents = []


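# Reshape rows returned by the document endpoint into the
# {"id": ..., "data": {"text": ..., "title": ...}} structure that Cohere's v2 Chat API accepts as documents.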
def convert_format(input_list):
    output_list = []
    for item in input_list:
        output_list.append({"id": str(item['id']), "data": {"text": item['text'], "title": item['title']}})
    return output_list


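# Per-session state holder; the lock serializes concurrent requests from the same session.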
class StateParams:
    def __init__(self):
        from threading import Lock
        self.llm_answer_text = ""
        self.rag_answer_text = ""
        self.converted_data = []
        self.documents = []
        self.lock = Lock()

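    # Plain LLM answer: send the prompt to Groq (llama-3.3-70b-versatile) with no retrieved context.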
    def get_llm_answer(self, prompt):
        with self.lock:
            client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
            system_prompt = {
                "role": "system",
                "content": "You are a helpful assistant, answer questions concisely."
            }

            user_input = prompt
            user_prompt = {
                "role": "user", "content": user_input
            }

            chat_history = [system_prompt, user_prompt]

            response = client.chat.completions.create(
                model="llama-3.3-70b-versatile",
                messages=chat_history,
                max_tokens=1024,
                temperature=0)

            kekka = response.choices[0].message.content
            self.llm_answer_text = kekka
            return self.llm_answer_text

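    # RAG answer: fetch the grounding documents once (cached in the module-level globals),
    # then call Cohere's Chat API with the documents attached.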
    def get_rag_answer(self, prompt):
        global converted_data
        global documents

        if len(converted_data) == 0:
            document = requests.get("https://www.ryhintl.com/dbjson/getjson?sqlcmd=select `id` as `id`,`title`,`snippet` as `text` from cohere_documents where id = '8'")
            documents1 = json.loads(document.content)

            converted_data = convert_format(documents1)
            documents = converted_data

        with self.lock:
            co = cohere.ClientV2(api_key=os.environ.get("COHERE_API_KEY"))
            system_message = "You are a helpful assistant, answer questions concisely."
            message = prompt
            messages = [
                {"role": "system", "content": system_message},
                {"role": "user", "content": message},
            ]
            response = co.chat(
                model="command-r-plus-08-2024",
                documents=documents,
                messages=messages
            )

            self.rag_answer_text = response.message.content[0].text
            return self.rag_answer_text


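# One StateParams instance per browser session, keyed by Gradio's session hash.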
instances = {}


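# Wired to ryhrag.load / ryhrag.unload below: create the session's state on page load and drop it when the session ends.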
def initialize_instance(request: gr.Request):
    instances[request.session_hash] = StateParams()
    return "セッションが初期化されました。"


def cleanup_instance(request: gr.Request):
    if request.session_hash in instances:
        del instances[request.session_hash]


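# Gradio fills the gr.Request parameter automatically; only the Dropdown value is passed in as an input.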
def llm_content(request: gr.Request, prompt: str):
    if request.session_hash in instances:
        instance = instances[request.session_hash]
        return instance.get_llm_answer(prompt)
    return "Error: セッションが初期化されていません。"


def rag_content(request: gr.Request, prompt: str):
    if request.session_hash in instances:
        instance = instances[request.session_hash]
        return instance.get_rag_answer(prompt)
    return "Error: セッションが初期化されていません。"


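# UI: a dropdown of sample prompts and two buttons that show the plain LLM answer next to the RAG answer.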
with gr.Blocks(title="ステート") as ryhrag:
    output = gr.Textbox(label="ステート")
    prompt = gr.Dropdown(
        ["YHプロジェクトの責任者は誰ですか?", "ER-RAGのアーキテクチャについて教えてください。", "YHプロジェクトのコストはいくらですか?"],
        label="プロンプト", info="プロンプトを選択してください。"
    )
    llm_output = gr.Textbox(label="LLM")
    rag_output = gr.Textbox(label="RAG")
    llm_btn = gr.Button("LLM")
    llm_btn.click(llm_content, inputs=prompt, outputs=llm_output)
    rag_btn = gr.Button("RAG")
    rag_btn.click(rag_content, inputs=prompt, outputs=rag_output)

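    # load runs once when a new session connects; unload fires when the tab is closed or reloaded,
    # releasing that session's state (unload is available in recent Gradio releases).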
    ryhrag.load(initialize_instance, inputs=None, outputs=output)
    ryhrag.unload(cleanup_instance)

ryhrag.launch()