MightyOctopus committed · Commit 615ce65 · verified · 1 Parent(s): 91c414a

Create app.py

Files changed (1):
  1. app.py +255 -0
app.py ADDED
@@ -0,0 +1,255 @@
##########====================================================================################
##########====================PRODUCTION VERSION -- vLLM, GRADIO=====================###########
##########====================================================================================
import os
import requests
from typing import List, Dict, Tuple
from datetime import datetime
from anthropic import Anthropic
from openai import OpenAI
import time
import gradio as gr
from tqdm import tqdm

ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
assert ANTHROPIC_API_KEY, "Set ANTHROPIC_API_KEY in Space settings"

VLLM_API = "http://localhost:8000/v1"

QWEN_MODEL = "Qwen/Qwen3-4B-Instruct-2507"
CLAUDE_MODEL = "claude-3-5-haiku-latest"

open_source_client = OpenAI(api_key="EMPTY", base_url=VLLM_API)
claude_client = Anthropic(api_key=ANTHROPIC_API_KEY)

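
### NOTE (added): `enable_model()` is referenced further down (in `generate_output`
### and on `demo.load`) but is never defined in the original file. The helper below
### is a hypothetical pre-warm sketch for the vLLM backend, assuming the server at
### VLLM_API is already running: it sends a tiny request so the first real user
### call does not pay any cold-start latency.
def enable_model():
    try:
        open_source_client.chat.completions.create(
            model=QWEN_MODEL,
            messages=[{"role": "user", "content": "ping"}],
            max_tokens=1,
        )
    except Exception:
        # The vLLM server may still be starting up; real calls will surface errors.
        pass

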
def invoke_messages(
    rows_num: int,
    business_category: str,
    columns: str,
    instruction: str,
) -> List[Dict[str, str]]:
    system_message = """
    You are a helpful assistant that generates synthetic mockup datasets,
    on request, for all kinds of businesses. The user specifies the data
    niche, the column types, and any other details; your job is to create
    high-quality mockup data they can use in demo apps or in a testing
    environment.
    """.strip()

    user_prompt = f"""
    Generate synthetic mockup data that fits the following instructions:
    - Number of rows: {rows_num}
    - Business area: {business_category}
    - Columns: {columns}
    - Other instructions: {instruction}
    - Make sure to deliver only the markdown content, without any additional comments.
    """.strip()

    system_message += "\n" + """
    If an SQL file is chosen as the output, make sure to produce the full
    SQL file format, including the CREATE TABLE statement.
    """.strip()

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_prompt}
    ]

    return messages

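### Quick illustration (added, not in the original commit) of what `invoke_messages`
### returns -- a standard OpenAI-style chat message list:
###
###     invoke_messages(10, "HR", "name, role, salary", "salaries in USD")
###     # -> [{"role": "system", "content": "You are a helpful assistant ..."},
###     #     {"role": "user",   "content": "Generate synthetic mockup data ..."}]
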
def pass_claude_msg(file_format: str, content: str) -> Tuple[str, str]:
    claude_sys_msg = """
    You are a helpful assistant that converts output generated by another model
    into the chosen file format: csv, sql, or json.
    NOTE: produce a result that contains only the markdown content,
    without any additional comments!
    """.strip()
    claude_user_msg = f"""
    Convert the following content into the {file_format} format:
    ----------------------------------------------------------------------
    {content}
    """.strip()

    return claude_sys_msg, claude_user_msg


def generate_output(messages):
    ### NOTE: this function is not called anywhere in this file. It expects a
    ### locally loaded transformers `model` and `tokenizer`, which the
    ### vLLM-served production path above does not create in-process.
    enable_model()

    inputs = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        return_dict=True,  ### IMPORTANT: to get a mapping
        tokenize=True,
        add_generation_prompt=True,
        padding=True,
        return_attention_mask=True
    ).to(model.device)

    # print(inputs)

    outputs = model.generate(
        **inputs,
        max_new_tokens=400,
        temperature=0.2
    )

    ### Get the length (number of tokens) of the input prompt
    prompt_len = inputs["input_ids"].shape[1]

    ### Slice the generated sequence to skip the prompt tokens
    gen_tokens = outputs[0][prompt_len:]

    # print(tokenizer.decode(gen_tokens, skip_special_tokens=True))

    return gen_tokens


def launch_claude_api(sys_msg, user_msg):
    response = claude_client.messages.create(
        model=CLAUDE_MODEL,
        system=sys_msg,
        max_tokens=400,
        temperature=0.1,
        messages=[
            {"role": "user", "content": user_msg}
        ]
    )
    return response.content[0].text


###============= Gradio Function =============###

def generate_mockup_data(category, num_data_rows, columns, a_instruction,
                         progress=gr.Progress()):
    progress(0.2, desc="Generating...")
    msg = invoke_messages(
        rows_num=int(num_data_rows or 10),
        business_category=category,
        columns=columns,
        instruction=a_instruction
    )

    resp = open_source_client.chat.completions.create(
        model=QWEN_MODEL,
        messages=msg,
        max_tokens=400,
        temperature=0.2,
        stream=False
    )
    progress(1.0, desc="Done")

    return resp.choices[0].message.content

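
### Optional sketch (added, not in the original commit): the same vLLM call could
### stream partial text into the result textbox by turning the handler into a
### generator; Gradio re-renders the output on every `yield`. The function name is
### hypothetical and nothing below is wired into the UI.
def generate_mockup_data_streaming(category, num_data_rows, columns, a_instruction):
    msg = invoke_messages(
        rows_num=int(num_data_rows or 10),
        business_category=category,
        columns=columns,
        instruction=a_instruction
    )
    stream = open_source_client.chat.completions.create(
        model=QWEN_MODEL,
        messages=msg,
        max_tokens=400,
        temperature=0.2,
        stream=True
    )
    partial = ""
    for chunk in stream:
        delta = chunk.choices[0].delta.content or ""
        partial += delta
        yield partial

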
def show_hidden_row():
    return gr.update(visible=True)


def make_file(btn_sort: str, category: str, content: str):
    '''
    btn_sort: one of the 3 download file types from the buttons -- csv, sql, or json
    category: business category or area that the data is associated with
    content: LLM-generated text output to write to a file
    '''

    if not content or not content.strip():
        raise gr.Error("The result content is empty. Cannot create a file.")

    try:
        sys_msg, user_msg = pass_claude_msg(btn_sort, content)
        claude_output = launch_claude_api(sys_msg, user_msg)

        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        filepath = f"/tmp/{category}_mockup_{ts}.{btn_sort}"

        with open(filepath, "w") as f:
            f.write(claude_output)

        return filepath
    except Exception as e:
        raise gr.Error(f"Failed to format or create the file: {e}")


###============= Gradio UI =============###

def render_interface():

    with gr.Blocks(title="Mockup Data Generator", css="footer {visibility:hidden}") as demo:
        category = gr.Textbox(
            label="Business Area/Category",
            placeholder="e.g. HR, Sales, Hospitality, Senior Care, E-commerce, Finance",
        )
        num_data_rows = gr.Number(
            label="Number of Rows",
            placeholder="Type number...",
            minimum=10,
            maximum=50,
            step=10,
            precision=0
        )
        columns = gr.Textbox(
            label="Insert Columns",
            placeholder="Comma-separated..."
        )
        a_instruction = gr.Textbox(
            label="Additional Instruction",
            placeholder="Any additional instruction. Leave blank if none.",
            lines=5
        )
        btn = gr.Button(
            value="Generate"
        )
        out = gr.Textbox(label="Result shown here.")

        buttons_row = gr.Row(visible=False)

        with buttons_row:
            btn_csv = gr.DownloadButton(label="Download csv", size="md", elem_classes=["download-btn"])
            btn_sql = gr.DownloadButton(label="Download sql", size="md", elem_classes=["download-btn"])
            btn_json = gr.DownloadButton(label="Download json", size="md", elem_classes=["download-btn"])

        chain = btn.click(
            fn=generate_mockup_data,
            inputs=[category, num_data_rows, columns, a_instruction],
            outputs=out,
            queue=True
        )

        chain = chain.then(
            fn=show_hidden_row,
            inputs=None,
            outputs=buttons_row,
        )

        btn_csv.click(
            lambda category, data: make_file("csv", category, data),
            inputs=[category, out],
            outputs=btn_csv
        )

        btn_sql.click(
            lambda category, data: make_file("sql", category, data),
            inputs=[category, out],
            outputs=btn_sql
        )

        btn_json.click(
            lambda category, data: make_file("json", category, data),
            inputs=[category, out],
            outputs=btn_json
        )

        ### Pre-warm the model right when the page loads, so the user does not
        ### pay the model-load latency when submitting the form.
        demo.load(lambda: enable_model(), queue=False)

    return demo


if __name__ == "__main__":
    app = render_interface()
    app.queue(default_concurrency_limit=1)
    app.launch(server_name="0.0.0.0", server_port=7860)
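
### ---------------------------------------------------------------------------
### Assumed setup (added note, not part of the original commit): the app expects
### an OpenAI-compatible vLLM server already listening on localhost:8000 before
### it starts, e.g. something along the lines of
###
###     vllm serve Qwen/Qwen3-4B-Instruct-2507 --port 8000
###     python app.py
###
### (exact vLLM flags depend on the installed version), plus ANTHROPIC_API_KEY
### set in the environment / Space secrets for the Claude-based file conversion.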