File size: 4,512 Bytes
b413dc1
 
 
 
 
 
 
 
 
 
 
e494dd8
 
b413dc1
 
 
 
 
 
 
 
 
 
 
 
e494dd8
 
b413dc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e494dd8
 
 
 
 
 
 
 
 
 
b413dc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129

import gradio as gr
from groq import Groq
import cohere
import requests
from dotenv import load_dotenv
import os
import json

load_dotenv(verbose=True)

# Module-level cache for the documents fetched for RAG, shared across sessions.
converted_data = []
# BUG FIX: was `documents = any`, which bound the *builtin function* `any`
# instead of an empty container. Use an empty list as the initial value.
documents = []

def convert_format(input_list):
    """Convert raw document rows into the structure Cohere chat expects.

    Each item of *input_list* must be a mapping with 'id', 'text' and
    'title' keys; the result stringifies the id and nests text/title
    under a 'data' key.

    Returns a new list; the input is not modified.
    """
    # Idiom: list comprehension instead of a manual append loop.
    return [
        {"id": str(item["id"]), "data": {"text": item["text"], "title": item["title"]}}
        for item in input_list
    ]

class StateParams:
    """Per-session state holding the latest LLM and RAG answers.

    One instance is created per Gradio session (see initialize_instance);
    a threading.Lock serializes the external API calls made on behalf of
    that session.
    """

    def __init__(self):
        from threading import Lock
        self.llm_answer_text = ""  # last answer from the plain LLM call
        self.rag_answer_text = ""  # last answer from the RAG call
        self.converted_data = []   # lazily-fetched documents in Cohere format
        self.documents = []        # documents handed to co.chat
        self.lock = Lock()  # Lock objects cannot be deepcopied

    def get_llm_answer(self, prompt):
        """Ask the Groq-hosted Llama model for a plain (non-RAG) answer.

        Stores the model's text in self.llm_answer_text and returns it.
        Requires GROQ_API_KEY in the environment.
        """
        with self.lock:
            client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
            chat_history = [
                {
                    "role": "system",
                    "content": "You are a helpful assistant, answer questions concisely."
                },
                {"role": "user", "content": prompt},
            ]

            response = client.chat.completions.create(
                model="llama-3.3-70b-versatile",
                messages=chat_history,
                max_tokens=1024,
                temperature=0)

            self.llm_answer_text = response.choices[0].message.content
            return self.llm_answer_text

    def get_rag_answer(self, prompt):
        """Answer *prompt* via Cohere chat grounded on the fetched documents.

        Stores the answer in self.rag_answer_text and returns it.
        Requires COHERE_API_KEY in the environment.
        """
        with self.lock:
            # BUG FIX: the original mutated module-level globals even though
            # __init__ initializes self.converted_data / self.documents, and
            # it did the fetch *outside* the lock. Cache on the instance and
            # fetch under the lock so two concurrent calls on the same
            # session cannot double-fetch.
            if not self.converted_data:
                document = requests.get("https://www.ryhintl.com/dbjson/getjson?sqlcmd=select `id` as `id`,`title`,`snippet` as `text` from cohere_documents where id = '8'")
                rows = json.loads(document.content)
                self.converted_data = convert_format(rows)
                self.documents = self.converted_data

            co = cohere.ClientV2(api_key=os.environ.get("COHERE_API_KEY"))
            messages = [
                {"role": "system", "content": "You are a helpful assistant, answer questions concisely."},
                {"role": "user", "content": prompt},
            ]
            response = co.chat(
                model="command-r-plus-08-2024",
                documents=self.documents,
                messages=messages
            )

            self.rag_answer_text = response.message.content[0].text
            return self.rag_answer_text

# Global dictionary to store user-specific instances
instances = {}

def initialize_instance(request: gr.Request):
    """Create fresh per-session state, keyed by the Gradio session hash."""
    state = StateParams()
    instances[request.session_hash] = state
    return "セッションが初期化されました。"

def cleanup_instance(request: gr.Request):
    """Drop the session's state, if present, when the page goes away."""
    instances.pop(request.session_hash, None)

def llm_content(request: gr.Request, prompt: str):
    """Route *prompt* to the session's plain-LLM handler."""
    instance = instances.get(request.session_hash)
    if instance is None:
        return "Error: セッションが初期化されていません。"
    return instance.get_llm_answer(prompt)

def rag_content(request: gr.Request, prompt: str):
    """Route *prompt* to the session's RAG handler."""
    instance = instances.get(request.session_hash)
    if instance is None:
        return "Error: セッションが初期化されていません。"
    return instance.get_rag_answer(prompt)

# UI wiring: one dropdown of canned prompts, two buttons that answer it
# with and without RAG grounding.
with gr.Blocks(title="ステート") as ryhrag:
    output = gr.Textbox(label="ステート")
    # NOTE: removed the leftover info="Will add more animals later!" —
    # a copy-paste artifact from the Gradio docs' animals example.
    prompt = gr.Dropdown(
        ["YHプロジェクトの責任者は誰ですか?", "ER-RAGのアーキテクチャについて教えてください。", "YHプロジェクトのコストはいくらですか?"],
        label="プロンプト",
    )
    llm_output = gr.Textbox(label="LLM")
    rag_output = gr.Textbox(label="RAG")
    llm_btn = gr.Button("LLM")
    llm_btn.click(llm_content, inputs=prompt, outputs=llm_output)
    rag_btn = gr.Button("RAG")
    rag_btn.click(rag_content, inputs=prompt, outputs=rag_output)

    # Initialize per-session state when the page loads.
    ryhrag.load(initialize_instance, inputs=None, outputs=output)
    # BUG FIX: Blocks.close() shuts down the app server and is not a
    # session event; the session-end hook (page closed/refreshed) is
    # Blocks.unload().
    ryhrag.unload(cleanup_instance)

ryhrag.launch()