Raft-ML-Dep commited on
Commit
ca70f47
·
verified ·
1 Parent(s): 9d49fb4

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +108 -0
  2. requirements.txt +173 -0
app.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import asyncio
4
+ import pandas as pd
5
+ from datetime import datetime
6
+ from langchain.schema.output_parser import StrOutputParser
7
+ from langchain.chat_models import ChatOpenAI
8
+ from langchain.prompts.chat import (
9
+ ChatPromptTemplate,
10
+ HumanMessagePromptTemplate,
11
+ SystemMessagePromptTemplate,
12
+ )
13
+
14
+ # Load OpenAI model with API key
15
+ openai_key = os.getenv('OPENAI_API_KEY')
16
+ model = ChatOpenAI(model_name="gpt-4-0125-preview", temperature=0.4)
17
+
18
+ # System prompt templates
19
+ raft_base_template = os.getenv('RAFT_BASE_TEMPLATE')
20
+
21
+ # System prompt template
22
+ customer_template = os.getenv('CUSTOMER_TEMPLATE')
23
+
24
+ raft_metaprompt_template = os.getenv('RAFT_METAPROMPT_TEMPLATE')
25
+ # Initialize DataFrame
26
+ results_df = pd.DataFrame(columns=['Experiment Number','Temperature', 'PromptID','Business Idea', 'Generated Domains'])
27
+
28
+ def extract_content_inside_brackets(json_string):
29
+ # Находим позицию первой открывающей и последней закрывающей скобки
30
+ start = json_string.find('[')
31
+ end = json_string.find(']', start)
32
+
33
+ # Если скобки найдены, извлекаем содержимое между ними
34
+ if start != -1 and end != -1:
35
+ content_inside_brackets = json_string[start+1:end]
36
+ return content_inside_brackets
37
+ else:
38
+ # Возвращаем None, если скобки не найдены
39
+ return None
40
+ list_of_prompts = ['customer_template', 'raft_base_template', 'raft_metaprompt_template']
41
+ # Function to generate domain names based on different templates
42
+ async def generate_domain_names(business_idea, temperature):
43
+ model.temperature = temperature # Update model temperature
44
+ prompts = [customer_template, raft_base_template, raft_metaprompt_template]
45
+ tasks = []
46
+ for prompt in prompts:
47
+ system_message_prompt = SystemMessagePromptTemplate.from_template(prompt)
48
+ human_message_prompt = HumanMessagePromptTemplate.from_template("User's input: ~###~ {user_input}~###~,")
49
+ combined_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
50
+ json_chain = combined_prompt | model | StrOutputParser() | extract_content_inside_brackets
51
+ task = asyncio.to_thread(json_chain.invoke, {"user_input": business_idea})
52
+ tasks.append(task)
53
+ domains = await asyncio.gather(*tasks)
54
+ # Update DataFrame and save to JSON
55
+ experiment_number = datetime.now().strftime("%Y%m%d%H%M%S")
56
+ for i, domain in enumerate(domains):
57
+ new_row = {
58
+ 'Experiment Number': experiment_number,
59
+ 'Business Idea': business_idea,
60
+ 'Temperature': temperature,
61
+ 'PromptID': list_of_prompts[i],
62
+ 'Generated Domains': domain
63
+ }
64
+ global results_df
65
+ results_df = pd.concat([results_df, pd.DataFrame([new_row])], ignore_index=True)
66
+ save_to_json(results_df)
67
+ return domains
68
+
69
+ # Function to save DataFrame to JSON
70
+ def save_to_json(df):
71
+ json_path = 'domain_name_generation_results.json'
72
+ df.to_json(json_path, orient='records', lines=True)
73
+
74
+ # Function to download the DataFrame as an Excel file
75
+ def download_excel():
76
+ df = pd.read_json('domain_name_generation_results.json', lines=True)
77
+ excel_path = 'domain_name_generation_results.xlsx'
78
+ df.to_excel(excel_path, index=False)
79
+ return excel_path
80
+
81
+ # Setup Gradio interface
82
+ def setup_interface():
83
+ with gr.Blocks() as iface:
84
+ with gr.Row():
85
+ text_input = gr.Textbox(label="Enter your business idea", placeholder="Type here...", lines=2)
86
+ temperature_slider = gr.Slider(minimum=0, maximum=1, step=0.01, label="Set Temperature")
87
+ submit_button = gr.Button("Generate")
88
+ with gr.Row():
89
+ output1 = gr.Textbox(label="Output for Customer Template")
90
+ output2 = gr.Textbox(label="Output for Base Raft Template")
91
+ output3 = gr.Textbox(label="Output for MetaPrompt Raft Template")
92
+ with gr.Row():
93
+ download_btn = gr.Button("Download Excel")
94
+ submit_button.click(
95
+ fn=generate_domain_names,
96
+ inputs=[text_input, temperature_slider],
97
+ outputs=[output1, output2, output3]
98
+ )
99
+ download_btn.click(
100
+ fn=download_excel,
101
+ inputs=None,
102
+ outputs=gr.File(label="Download Excel")
103
+ )
104
+ return iface
105
+
106
+
107
+ iface = setup_interface()
108
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1 ; python_version >= "3.10" and python_version < "4.0"
2
+ aiohttp-retry==2.8.3 ; python_version >= "3.10" and python_version < "4.0"
3
+ aiohttp==3.9.5 ; python_version >= "3.10" and python_version < "4.0"
4
+ aiosignal==1.3.1 ; python_version >= "3.10" and python_version < "4.0"
5
+ altair==5.3.0 ; python_version >= "3.10" and python_version < "4.0"
6
+ amqp==5.2.0 ; python_version >= "3.10" and python_version < "4.0"
7
+ annotated-types==0.6.0 ; python_version >= "3.10" and python_version < "4.0"
8
+ antlr4-python3-runtime==4.9.3 ; python_version >= "3.10" and python_version < "4.0"
9
+ anyio==4.3.0 ; python_version >= "3.10" and python_version < "4.0"
10
+ appdirs==1.4.4 ; python_version >= "3.10" and python_version < "4.0"
11
+ appnope==0.1.4 ; python_version >= "3.10" and python_version < "4.0" and platform_system == "Darwin"
12
+ asttokens==2.4.1 ; python_version >= "3.10" and python_version < "4.0"
13
+ asyncssh==2.14.2 ; python_version >= "3.10" and python_version < "4.0"
14
+ atpublic==4.1.0 ; python_version >= "3.10" and python_version < "4.0"
15
+ attrs==23.2.0 ; python_version >= "3.10" and python_version < "4.0"
16
+ billiard==4.2.0 ; python_version >= "3.10" and python_version < "4.0"
17
+ celery==5.4.0 ; python_version >= "3.10" and python_version < "4.0"
18
+ certifi==2024.2.2 ; python_version >= "3.10" and python_version < "4.0"
19
+ cffi==1.16.0 ; python_version >= "3.10" and python_version < "4.0"
20
+ charset-normalizer==3.3.2 ; python_version >= "3.10" and python_version < "4.0"
21
+ click-didyoumean==0.3.1 ; python_version >= "3.10" and python_version < "4.0"
22
+ click-plugins==1.1.1 ; python_version >= "3.10" and python_version < "4.0"
23
+ click-repl==0.3.0 ; python_version >= "3.10" and python_version < "4.0"
24
+ click==8.1.7 ; python_version >= "3.10" and python_version < "4.0"
25
+ colorama==0.4.6 ; python_version >= "3.10" and python_version < "4.0"
26
+ comm==0.2.2 ; python_version >= "3.10" and python_version < "4.0"
27
+ configobj==5.0.8 ; python_version >= "3.10" and python_version < "4.0"
28
+ contourpy==1.2.1 ; python_version >= "3.10" and python_version < "4.0"
29
+ cryptography==42.0.5 ; python_version >= "3.10" and python_version < "4.0"
30
+ cycler==0.12.1 ; python_version >= "3.10" and python_version < "4.0"
31
+ dataclasses-json==0.6.4 ; python_version >= "3.10" and python_version < "4.0"
32
+ debugpy==1.8.1 ; python_version >= "3.10" and python_version < "4.0"
33
+ decorator==5.1.1 ; python_version >= "3.10" and python_version < "4.0"
34
+ dictdiffer==0.9.0 ; python_version >= "3.10" and python_version < "4.0"
35
+ diskcache==5.6.3 ; python_version >= "3.10" and python_version < "4.0"
36
+ distro==1.9.0 ; python_version >= "3.10" and python_version < "4.0"
37
+ dpath==2.1.6 ; python_version >= "3.10" and python_version < "4.0"
38
+ dulwich==0.21.7 ; python_version >= "3.10" and python_version < "4.0"
39
+ dvc-data==3.15.1 ; python_version >= "3.10" and python_version < "4.0"
40
+ dvc-http==2.32.0 ; python_version >= "3.10" and python_version < "4.0"
41
+ dvc-objects==5.1.0 ; python_version >= "3.10" and python_version < "4.0"
42
+ dvc-render==1.0.2 ; python_version >= "3.10" and python_version < "4.0"
43
+ dvc-studio-client==0.20.0 ; python_version >= "3.10" and python_version < "4.0"
44
+ dvc-task==0.4.0 ; python_version >= "3.10" and python_version < "4.0"
45
+ dvc==3.50.0 ; python_version >= "3.10" and python_version < "4.0"
46
+ entrypoints==0.4 ; python_version >= "3.10" and python_version < "4.0"
47
+ et-xmlfile==1.1.0 ; python_version >= "3.10" and python_version < "4.0"
48
+ executing==2.0.1 ; python_version >= "3.10" and python_version < "4.0"
49
+ fastapi==0.110.2 ; python_version >= "3.10" and python_version < "4.0"
50
+ ffmpy==0.3.2 ; python_version >= "3.10" and python_version < "4.0"
51
+ filelock==3.13.4 ; python_version >= "3.10" and python_version < "4.0"
52
+ flatten-dict==0.4.2 ; python_version >= "3.10" and python_version < "4.0"
53
+ flufl-lock==7.1.1 ; python_version >= "3.10" and python_version < "4.0"
54
+ fonttools==4.51.0 ; python_version >= "3.10" and python_version < "4.0"
55
+ frozenlist==1.4.1 ; python_version >= "3.10" and python_version < "4.0"
56
+ fsspec==2024.3.1 ; python_version >= "3.10" and python_version < "4.0"
57
+ fsspec[http]==2024.3.1 ; python_version >= "3.10" and python_version < "4.0"
58
+ fsspec[tqdm]==2024.3.1 ; python_version >= "3.10" and python_version < "4.0"
59
+ funcy==2.0 ; python_version >= "3.10" and python_version < "4.0"
60
+ gitdb==4.0.11 ; python_version >= "3.10" and python_version < "4.0"
61
+ gitpython==3.1.43 ; python_version >= "3.10" and python_version < "4.0"
62
+ gradio-client==0.15.1 ; python_version >= "3.10" and python_version < "4.0"
63
+ gradio==4.27.0 ; python_version >= "3.10" and python_version < "4.0"
64
+ grandalf==0.8 ; python_version >= "3.10" and python_version < "4.0"
65
+ greenlet==3.0.3 ; python_version >= "3.10" and python_version < "4.0" and (platform_machine == "win32" or platform_machine == "WIN32" or platform_machine == "AMD64" or platform_machine == "amd64" or platform_machine == "x86_64" or platform_machine == "ppc64le" or platform_machine == "aarch64")
66
+ gto==1.7.1 ; python_version >= "3.10" and python_version < "4.0"
67
+ h11==0.14.0 ; python_version >= "3.10" and python_version < "4.0"
68
+ httpcore==1.0.5 ; python_version >= "3.10" and python_version < "4.0"
69
+ httpx==0.27.0 ; python_version >= "3.10" and python_version < "4.0"
70
+ huggingface-hub==0.22.2 ; python_version >= "3.10" and python_version < "4.0"
71
+ hydra-core==1.3.2 ; python_version >= "3.10" and python_version < "4.0"
72
+ idna==3.7 ; python_version >= "3.10" and python_version < "4.0"
73
+ importlib-resources==6.4.0 ; python_version >= "3.10" and python_version < "4.0"
74
+ ipykernel==6.29.4 ; python_version >= "3.10" and python_version < "4.0"
75
+ ipython==8.23.0 ; python_version >= "3.10" and python_version < "4.0"
76
+ iterative-telemetry==0.0.8 ; python_version >= "3.10" and python_version < "4.0"
77
+ jedi==0.19.1 ; python_version >= "3.10" and python_version < "4.0"
78
+ jinja2==3.1.3 ; python_version >= "3.10" and python_version < "4.0"
79
+ jsonpatch==1.33 ; python_version >= "3.10" and python_version < "4.0"
80
+ jsonpointer==2.4 ; python_version >= "3.10" and python_version < "4.0"
81
+ jsonschema-specifications==2023.12.1 ; python_version >= "3.10" and python_version < "4.0"
82
+ jsonschema==4.21.1 ; python_version >= "3.10" and python_version < "4.0"
83
+ jupyter-client==8.6.1 ; python_version >= "3.10" and python_version < "4.0"
84
+ jupyter-core==5.7.2 ; python_version >= "3.10" and python_version < "4.0"
85
+ kiwisolver==1.4.5 ; python_version >= "3.10" and python_version < "4.0"
86
+ kombu==5.3.7 ; python_version >= "3.10" and python_version < "4.0"
87
+ langchain-community==0.0.33 ; python_version >= "3.10" and python_version < "4.0"
88
+ langchain-core==0.1.44 ; python_version >= "3.10" and python_version < "4.0"
89
+ langchain-text-splitters==0.0.1 ; python_version >= "3.10" and python_version < "4.0"
90
+ langchain==0.1.16 ; python_version >= "3.10" and python_version < "4.0"
91
+ langsmith==0.1.48 ; python_version >= "3.10" and python_version < "4.0"
92
+ markdown-it-py==3.0.0 ; python_version >= "3.10" and python_version < "4.0"
93
+ markupsafe==2.1.5 ; python_version >= "3.10" and python_version < "4.0"
94
+ marshmallow==3.21.1 ; python_version >= "3.10" and python_version < "4.0"
95
+ matplotlib-inline==0.1.7 ; python_version >= "3.10" and python_version < "4.0"
96
+ matplotlib==3.8.4 ; python_version >= "3.10" and python_version < "4.0"
97
+ mdurl==0.1.2 ; python_version >= "3.10" and python_version < "4.0"
98
+ multidict==6.0.5 ; python_version >= "3.10" and python_version < "4.0"
99
+ mypy-extensions==1.0.0 ; python_version >= "3.10" and python_version < "4.0"
100
+ nest-asyncio==1.6.0 ; python_version >= "3.10" and python_version < "4.0"
101
+ networkx==3.3 ; python_version >= "3.10" and python_version < "4.0"
102
+ numpy==1.26.4 ; python_version >= "3.10" and python_version < "4.0"
103
+ omegaconf==2.3.0 ; python_version >= "3.10" and python_version < "4.0"
104
+ openai==1.21.2 ; python_version >= "3.10" and python_version < "4.0"
105
+ openpyxl==3.1.2 ; python_version >= "3.10" and python_version < "4.0"
106
+ orjson==3.10.1 ; python_version >= "3.10" and python_version < "4.0"
107
+ packaging==23.2 ; python_version >= "3.10" and python_version < "4.0"
108
+ pandas==2.2.2 ; python_version >= "3.10" and python_version < "4.0"
109
+ parso==0.8.4 ; python_version >= "3.10" and python_version < "4.0"
110
+ pathspec==0.12.1 ; python_version >= "3.10" and python_version < "4.0"
111
+ pexpect==4.9.0 ; python_version >= "3.10" and python_version < "4.0" and (sys_platform != "win32" and sys_platform != "emscripten")
112
+ pillow==10.3.0 ; python_version >= "3.10" and python_version < "4.0"
113
+ platformdirs==3.11.0 ; python_version >= "3.10" and python_version < "4.0"
114
+ prompt-toolkit==3.0.43 ; python_version >= "3.10" and python_version < "4.0"
115
+ psutil==5.9.8 ; python_version >= "3.10" and python_version < "4.0"
116
+ ptyprocess==0.7.0 ; python_version >= "3.10" and python_version < "4.0" and (sys_platform != "win32" and sys_platform != "emscripten")
117
+ pure-eval==0.2.2 ; python_version >= "3.10" and python_version < "4.0"
118
+ pycparser==2.22 ; python_version >= "3.10" and python_version < "4.0"
119
+ pydantic-core==2.18.1 ; python_version >= "3.10" and python_version < "4.0"
120
+ pydantic==2.7.0 ; python_version >= "3.10" and python_version < "4.0"
121
+ pydot==2.0.0 ; python_version >= "3.10" and python_version < "4.0"
122
+ pydub==0.25.1 ; python_version >= "3.10" and python_version < "4.0"
123
+ pygit2==1.14.1 ; python_version >= "3.10" and python_version < "4.0"
124
+ pygments==2.17.2 ; python_version >= "3.10" and python_version < "4.0"
125
+ pygtrie==2.5.0 ; python_version >= "3.10" and python_version < "4.0"
126
+ pyparsing==3.1.2 ; python_version >= "3.10" and python_version < "4.0"
127
+ python-dateutil==2.9.0.post0 ; python_version >= "3.10" and python_version < "4.0"
128
+ python-dotenv==1.0.1 ; python_version >= "3.10" and python_version < "4.0"
129
+ python-multipart==0.0.9 ; python_version >= "3.10" and python_version < "4.0"
130
+ pytz==2024.1 ; python_version >= "3.10" and python_version < "4.0"
131
+ pywin32==306 ; python_version >= "3.10" and python_version < "4.0" and sys_platform == "win32"
132
+ pyyaml==6.0.1 ; python_version >= "3.10" and python_version < "4.0"
133
+ pyzmq==26.0.0 ; python_version >= "3.10" and python_version < "4.0"
134
+ referencing==0.34.0 ; python_version >= "3.10" and python_version < "4.0"
135
+ requests==2.31.0 ; python_version >= "3.10" and python_version < "4.0"
136
+ rich==13.7.1 ; python_version >= "3.10" and python_version < "4.0"
137
+ rpds-py==0.18.0 ; python_version >= "3.10" and python_version < "4.0"
138
+ ruamel-yaml-clib==0.2.8 ; platform_python_implementation == "CPython" and python_version < "3.13" and python_version >= "3.10"
139
+ ruamel-yaml==0.18.6 ; python_version >= "3.10" and python_version < "4.0"
140
+ ruff==0.4.1 ; python_version >= "3.10" and python_version < "4.0" and sys_platform != "emscripten"
141
+ scmrepo==3.3.1 ; python_version >= "3.10" and python_version < "4.0"
142
+ semantic-version==2.10.0 ; python_version >= "3.10" and python_version < "4.0"
143
+ semver==3.0.2 ; python_version >= "3.10" and python_version < "4.0"
144
+ setuptools==69.5.1 ; python_version >= "3.10" and python_version < "4.0"
145
+ shellingham==1.5.4 ; python_version >= "3.10" and python_version < "4.0"
146
+ shortuuid==1.0.13 ; python_version >= "3.10" and python_version < "4.0"
147
+ shtab==1.7.1 ; python_version >= "3.10" and python_version < "4.0"
148
+ six==1.16.0 ; python_version >= "3.10" and python_version < "4.0"
149
+ smmap==5.0.1 ; python_version >= "3.10" and python_version < "4.0"
150
+ sniffio==1.3.1 ; python_version >= "3.10" and python_version < "4.0"
151
+ sqlalchemy==2.0.29 ; python_version >= "3.10" and python_version < "4.0"
152
+ sqltrie==0.11.0 ; python_version >= "3.10" and python_version < "4.0"
153
+ stack-data==0.6.3 ; python_version >= "3.10" and python_version < "4.0"
154
+ starlette==0.37.2 ; python_version >= "3.10" and python_version < "4.0"
155
+ tabulate==0.9.0 ; python_version >= "3.10" and python_version < "4.0"
156
+ tenacity==8.2.3 ; python_version >= "3.10" and python_version < "4.0"
157
+ tomlkit==0.12.0 ; python_version >= "3.10" and python_version < "4.0"
158
+ toolz==0.12.1 ; python_version >= "3.10" and python_version < "4.0"
159
+ tornado==6.4 ; python_version >= "3.10" and python_version < "4.0"
160
+ tqdm==4.66.2 ; python_version >= "3.10" and python_version < "4.0"
161
+ traitlets==5.14.2 ; python_version >= "3.10" and python_version < "4.0"
162
+ typer==0.12.3 ; python_version >= "3.10" and python_version < "4.0"
163
+ typing-extensions==4.11.0 ; python_version >= "3.10" and python_version < "4.0"
164
+ typing-inspect==0.9.0 ; python_version >= "3.10" and python_version < "4.0"
165
+ tzdata==2024.1 ; python_version >= "3.10" and python_version < "4.0"
166
+ urllib3==2.2.1 ; python_version >= "3.10" and python_version < "4.0"
167
+ uvicorn==0.29.0 ; python_version >= "3.10" and python_version < "4.0" and sys_platform != "emscripten"
168
+ vine==5.1.0 ; python_version >= "3.10" and python_version < "4.0"
169
+ voluptuous==0.14.2 ; python_version >= "3.10" and python_version < "4.0"
170
+ wcwidth==0.2.13 ; python_version >= "3.10" and python_version < "4.0"
171
+ websockets==11.0.3 ; python_version >= "3.10" and python_version < "4.0"
172
+ yarl==1.9.4 ; python_version >= "3.10" and python_version < "4.0"
173
+ zc-lockfile==3.0.post1 ; python_version >= "3.10" and python_version < "4.0"