tax-free commited on
Commit
c1466e0
·
verified ·
1 Parent(s): 9ff8ac5

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +203 -0
  2. merged_data.csv +0 -0
  3. requirements.txt +158 -0
app.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import random
3
+ import time
4
+ import csv
5
+ import os
6
+ import pandas as pd
7
+ from openai import OpenAI
8
+ import numpy as np
9
+ import ast
10
+ import json
11
+
12
+
13
+ # ユーザーの諸々の情報やプロンプトを受け取り返答する
14
+ def search_teacher(job_type, job_start_dates, tool_function_arguments):
15
+ time.sleep(1)
16
+
17
+ user_request = ', '.join(list(tool_function_arguments.values())[:-2])
18
+ user_message_embedded_vector = \
19
+ client.embeddings.create(input=[user_request.replace("\n", " ")], model='text-embedding-3-small').data[
20
+ 0].embedding
21
+
22
+ data = pd.read_csv(os.environ['DATA_PATH'])
23
+
24
+ # job_type の日本語を英語に変換
25
+ if job_type == '常勤':
26
+ job_type = 'full time'
27
+ elif job_type == '非常勤':
28
+ job_type = 'part time'
29
+
30
+ # job_start_dates の日本語を英語に変換
31
+ job_start_date_translation = {
32
+ '今年度': 'This Year',
33
+ '来年度': 'Next Year',
34
+ '来来年度': 'Year After Next'
35
+ }
36
+ # job_start_dates が文字列の場合リストに変換(単一選択の場合を考慮)
37
+ if isinstance(job_start_dates, str):
38
+ job_start_dates = [job_start_dates]
39
+ # 英語に変換
40
+ job_start_dates = [job_start_date_translation.get(date, date) for date in job_start_dates]
41
+
42
+ # job_type と job_start_dates でフィルタリング
43
+ filtered_data = data[(data['job_type'] == job_type) & (data['job_start_date'].isin(job_start_dates))].copy()
44
+
45
+ result = []
46
+
47
+ for index, row in filtered_data.iterrows():
48
+ teacher_embedded_vector = ast.literal_eval(row['embedding'])
49
+
50
+ similarity = np.dot(user_message_embedded_vector, teacher_embedded_vector) / (
51
+ np.linalg.norm(user_message_embedded_vector) * np.linalg.norm(teacher_embedded_vector))
52
+
53
+ if len(result) < 3:
54
+ result.append((index, similarity))
55
+ result.sort(key=lambda x: x[1], reverse=True)
56
+ else:
57
+ if similarity > result[-1][1]:
58
+ result[-1] = (index, similarity)
59
+ result.sort(key=lambda x: x[1], reverse=True)
60
+
61
+ formatted_result = []
62
+ for index, similarity in result:
63
+ name = filtered_data.loc[index, 'name']
64
+ temp = f"{name}: {similarity}"
65
+ formatted_result.append(temp)
66
+
67
+ return ', '.join(formatted_result)
68
+
69
+
70
+ def openai_api(job_type, job_start_dates, history):
71
+ # GPTにユーザーの入力を送信
72
+ message = client.beta.threads.messages.create(
73
+ thread_id=thread.id,
74
+ role="user",
75
+ content=history[-1][0]
76
+ )
77
+
78
+ # 送信した入力を実行
79
+ run = client.beta.threads.runs.create(
80
+ thread_id=thread.id,
81
+ assistant_id=assistant.id,
82
+ )
83
+
84
+ while True:
85
+ run = client.beta.threads.runs.retrieve(
86
+ thread_id=thread.id,
87
+ run_id=run.id
88
+ )
89
+ if run.status == 'completed':
90
+ break
91
+ elif run.status == 'requires_action':
92
+ tool_id = run.required_action.submit_tool_outputs.tool_calls[0].id
93
+ tool_function_arguments = json.loads(
94
+ run.required_action.submit_tool_outputs.tool_calls[0].function.arguments)
95
+
96
+ tool_function_output = search_teacher(job_type, job_start_dates, tool_function_arguments)
97
+
98
+ run = client.beta.threads.runs.submit_tool_outputs(
99
+ thread_id=thread.id,
100
+ run_id=run.id,
101
+ tool_outputs=[
102
+ {
103
+ "tool_call_id": tool_id,
104
+ "output": tool_function_output,
105
+ }
106
+ ]
107
+ )
108
+ time.sleep(3)
109
+
110
+ time.sleep(0.5)
111
+
112
+ messages = client.beta.threads.messages.list(
113
+ thread_id=thread.id,
114
+ order="asc"
115
+ )
116
+
117
+ return messages.data[-1].content[0].text.value
118
+
119
+
120
+ def user(user_job_type, user_job_start_date, user_message, history):
121
+ gr.Info(str(user_job_type))
122
+ gr.Info(str(user_job_start_date))
123
+
124
+ return None, history + [[user_message, None]]
125
+
126
+
127
+ def bot(job_type, job_start_date, history):
128
+ prompt = ""
129
+ # apiを叩くためにデータを加工するなりする
130
+ for chat in history[:-1]:
131
+ prompt += '"' + chat[0] + '", "' + chat[1] + '"'
132
+
133
+ bot_message = openai_api(job_type, job_start_date, history)
134
+
135
+ history[-1][1] = ""
136
+ for character in bot_message:
137
+ history[-1][1] += character
138
+ time.sleep(0.01)
139
+
140
+ yield history
141
+
142
+
143
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
144
+ global assistant
145
+ global thread
146
+ global client
147
+
148
+ client = OpenAI(max_retries=5)
149
+
150
+ assistant = client.beta.assistants.create(
151
+ name='connpath_demo',
152
+ instructions=os.environ['INSTRUCTIONS'],
153
+ model="gpt-4-0125-preview",
154
+ tools=[{
155
+ "type": "function",
156
+ "function": {
157
+ "name": "connpath_demo_gpt",
158
+ "description": "検索を支援する",
159
+ "parameters": {
160
+ "type": "object",
161
+ "properties": {
162
+ "educational_goals": {"type": "string", "description": "教育の目的"},
163
+ "student_profile": {"type": "string", "description": "対象のプロフィール"},
164
+ "required_skills_experience": {"type": "string", "description": "求められるスキル"},
165
+ "teaching_method_environment": {"type": "string", "description": "教育する環境"},
166
+ "evaluation_feedback": {"type": "string", "description": "教育の評価方法"}
167
+ },
168
+ "required": ["educational_goals", "student_profile", "required_skills_experience",
169
+ "teaching_method_environment", "evaluation_feedback"]
170
+ }
171
+ }
172
+ }]
173
+ )
174
+
175
+ thread = client.beta.threads.create()
176
+
177
+ # これがhistoryを保持
178
+ chatbot = gr.Chatbot(show_copy_button=True)
179
+
180
+ # これがuser_job_typeを保持 (str型)
181
+ job_type = gr.Radio(["常勤", "非常勤"],
182
+ label="Job Type",
183
+ info="探している雇用形態について")
184
+
185
+ # これがuser_job_start_dateを保持 (list型)
186
+ job_start_date = gr.CheckboxGroup(["今年度", "来年度", "来来年度"],
187
+ label="Start Date",
188
+ info="探している就業時期について")
189
+
190
+ # これがuser_messageを保持
191
+ msg = gr.Textbox(label="input message")
192
+ clear = gr.Button("clear history")
193
+
194
+ msg.submit(user,
195
+ [job_type, job_start_date, msg, chatbot],
196
+ [msg, chatbot],
197
+ queue=True
198
+ ).then(bot, [job_type, job_start_date, chatbot], chatbot)
199
+ clear.click(lambda: None, None, chatbot, queue=False)
200
+
201
+ if __name__ == "__main__":
202
+ demo.queue()
203
+ demo.launch(auth=(os.environ['USER_NAME'], os.environ['PASSWORD']))
merged_data.csv ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aiohttp==3.9.3
3
+ aiosignal==1.3.1
4
+ altair==5.2.0
5
+ annotated-types==0.6.0
6
+ anyio==3.7.1
7
+ argon2-cffi==23.1.0
8
+ argon2-cffi-bindings==21.2.0
9
+ arrow==1.2.3
10
+ asttokens==2.2.1
11
+ async-lru==2.0.4
12
+ async-timeout==4.0.3
13
+ attrs==23.1.0
14
+ Babel==2.12.1
15
+ backcall==0.2.0
16
+ beautifulsoup4==4.12.2
17
+ bleach==6.0.0
18
+ certifi==2023.7.22
19
+ cffi==1.15.1
20
+ charset-normalizer==3.2.0
21
+ click==8.1.7
22
+ colorama==0.4.6
23
+ comm==0.1.4
24
+ contourpy==1.1.0
25
+ cycler==0.11.0
26
+ dataclasses-json==0.6.4
27
+ debugpy==1.6.7.post1
28
+ decorator==5.1.1
29
+ defusedxml==0.7.1
30
+ distro==1.9.0
31
+ et-xmlfile==1.1.0
32
+ exceptiongroup==1.1.3
33
+ executing==1.2.0
34
+ fastapi==0.109.2
35
+ fastjsonschema==2.18.0
36
+ ffmpy==0.3.2
37
+ filelock==3.13.1
38
+ fonttools==4.42.0
39
+ fqdn==1.5.1
40
+ frozenlist==1.4.1
41
+ fsspec==2024.2.0
42
+ gradio==4.19.2
43
+ gradio_client==0.10.1
44
+ greenlet==3.0.3
45
+ h11==0.14.0
46
+ httpcore==1.0.4
47
+ httpx==0.27.0
48
+ huggingface-hub==0.20.3
49
+ idna==3.4
50
+ importlib-metadata==6.8.0
51
+ importlib-resources==6.0.1
52
+ ipykernel==6.25.1
53
+ ipython==8.14.0
54
+ isoduration==20.11.0
55
+ jedi==0.19.0
56
+ Jinja2==3.1.2
57
+ json5==0.9.14
58
+ jsonpatch==1.33
59
+ jsonpointer==2.4
60
+ jsonschema==4.19.0
61
+ jsonschema-specifications==2023.7.1
62
+ jupyter-events==0.7.0
63
+ jupyter-lsp==2.2.0
64
+ jupyter_client==8.3.0
65
+ jupyter_core==5.3.1
66
+ jupyter_server==2.7.1
67
+ jupyter_server_terminals==0.4.4
68
+ jupyterlab==4.0.5
69
+ jupyterlab-pygments==0.2.2
70
+ jupyterlab_server==2.24.0
71
+ kiwisolver==1.4.4
72
+ langchain==0.1.9
73
+ langchain-community==0.0.24
74
+ langchain-core==0.1.26
75
+ langsmith==0.1.7
76
+ markdown-it-py==3.0.0
77
+ MarkupSafe==2.1.3
78
+ marshmallow==3.20.2
79
+ matplotlib==3.7.2
80
+ matplotlib-inline==0.1.6
81
+ mdurl==0.1.2
82
+ mistune==3.0.1
83
+ multidict==6.0.5
84
+ mypy-extensions==1.0.0
85
+ nbclient==0.8.0
86
+ nbconvert==7.7.3
87
+ nbformat==5.9.2
88
+ nest-asyncio==1.5.7
89
+ notebook_shim==0.2.3
90
+ numpy==1.25.2
91
+ openai==1.12.0
92
+ openpyxl==3.1.2
93
+ orjson==3.9.15
94
+ overrides==7.4.0
95
+ packaging==23.2
96
+ pandas==2.0.3
97
+ pandocfilters==1.5.0
98
+ parso==0.8.3
99
+ pexpect==4.8.0
100
+ pickleshare==0.7.5
101
+ Pillow==10.0.0
102
+ platformdirs==3.10.0
103
+ prometheus-client==0.17.1
104
+ prompt-toolkit==3.0.39
105
+ psutil==5.9.5
106
+ ptyprocess==0.7.0
107
+ pure-eval==0.2.2
108
+ pycparser==2.21
109
+ pydantic==2.6.2
110
+ pydantic_core==2.16.3
111
+ pydub==0.25.1
112
+ Pygments==2.16.1
113
+ pyparsing==3.0.9
114
+ python-dateutil==2.8.2
115
+ python-json-logger==2.0.7
116
+ python-multipart==0.0.9
117
+ pytz==2023.3
118
+ PyYAML==6.0.1
119
+ pyzmq==25.1.1
120
+ referencing==0.30.2
121
+ requests==2.31.0
122
+ rfc3339-validator==0.1.4
123
+ rfc3986-validator==0.1.1
124
+ rich==13.7.0
125
+ rpds-py==0.9.2
126
+ ruff==0.2.2
127
+ semantic-version==2.10.0
128
+ Send2Trash==1.8.2
129
+ shellingham==1.5.4
130
+ six==1.16.0
131
+ sniffio==1.3.0
132
+ soupsieve==2.4.1
133
+ SQLAlchemy==2.0.27
134
+ stack-data==0.6.2
135
+ starlette==0.36.3
136
+ tenacity==8.2.3
137
+ terminado==0.17.1
138
+ tinycss2==1.2.1
139
+ tomli==2.0.1
140
+ tomlkit==0.12.0
141
+ toolz==0.12.1
142
+ tornado==6.3.3
143
+ tqdm==4.66.2
144
+ traitlets==5.9.0
145
+ typer==0.9.0
146
+ typing-inspect==0.9.0
147
+ typing_extensions==4.9.0
148
+ tzdata==2023.3
149
+ uri-template==1.3.0
150
+ urllib3==2.0.4
151
+ uvicorn==0.27.1
152
+ wcwidth==0.2.6
153
+ webcolors==1.13
154
+ webencodings==0.5.1
155
+ websocket-client==1.6.1
156
+ websockets==11.0.3
157
+ yarl==1.9.4
158
+ zipp==3.16.2