|
|
|
|
|
|
|
|
|
|
|
import os, time |
|
|
import requests |
|
|
from typing import List, Dict, Tuple |
|
|
from datetime import datetime |
|
|
from anthropic import Anthropic |
|
|
from openai import OpenAI |
|
|
import gradio as gr |
|
|
from tqdm import tqdm |
|
|
|
|
|
# Anthropic API key, read from the environment. Import aborts with an
# AssertionError when it is unset or empty.
# NOTE(review): `assert` statements are removed when Python runs with -O, so
# this check would not fire in optimized mode.
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
assert ANTHROPIC_API_KEY, "Set ANTHROPIC_API_KEY in Space settings"

# Base URL handed to the OpenAI client below (a server on localhost:8000).
VLLM_API = "http://localhost:8000/v1"

# Model identifiers passed to the two clients' create() calls.
QWEN_MODEL = "Qwen/Qwen1.5-4B-Chat-AWQ"
CLAUDE_MODEL = "claude-3-5-haiku-latest"

# OpenAI-style client pointed at VLLM_API; "EMPTY" is a placeholder key.
open_source_client = OpenAI(api_key="EMPTY", base_url=VLLM_API)
# Anthropic client used by launch_claude_api().
claude_client = Anthropic(api_key=ANTHROPIC_API_KEY)
|
|
|
|
|
|
|
|
def wait_for_vllm_ready(timeout=900, poll_interval=2,
                        health_url="http://localhost:8000/health"):
    """Block until the health endpoint answers HTTP 200.

    Args:
        timeout: Maximum number of seconds to keep polling.
        poll_interval: Seconds to sleep between two probes.
        health_url: URL that is probed with a GET request.

    Returns:
        True as soon as a probe returns status code 200.

    Raises:
        RuntimeError: If no probe succeeded before ``timeout`` elapsed. The
            last network error, if any, is attached as ``__cause__``.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    last_error = None

    while time.monotonic() < deadline:
        try:
            response = requests.get(health_url, timeout=3)
        except requests.RequestException as exc:
            # Connection refused / timed out while the server is still
            # booting is expected; remember it for the final error report.
            last_error = exc
        else:
            if response.status_code == 200:
                return True
        time.sleep(poll_interval)

    raise RuntimeError("vLLM did not start within timeout") from last_error
|
|
|
|
|
def invoke_messages(
    rows_num: int,
    business_category: str,
    columns: str,
    instruction: str,
) -> List[Dict[str, str]]:
    """Build the chat messages asking the model for a synthetic dataset.

    Args:
        rows_num: Number of data rows to request.
        business_category: Business area the data should belong to.
        columns: Column names to include, as entered by the user.
        instruction: Free-form extra instruction; blank or None is rendered
            as "None" so the bullet is never left dangling.

    Returns:
        A two-element list: a ``system`` message followed by a ``user``
        message, each a dict with ``role`` and ``content`` keys.
    """
    system_lines = (
        "You are a helpful assistant generating synthetic mockup dataset as per",
        "user's request across all types of businesses and sorts.",
        "User's specific request for the data niche, data column types, and all",
        "other details and your job is to create wonderful mockup data for them",
        "to use for their demo apps or develop in a testing environment.",
        # The SQL addendum starts on its own line; previously it was glued
        # directly onto the preceding sentence with no separator.
        "In the case of sql file selection as an output, make sure to",
        "contain the full sql file format, including CREATE TABLE command.",
    )

    extra_instruction = (instruction or "").strip() or "None"
    user_lines = (
        "Generate a synthetic mockup data that fits the following instruction:",
        f"- Number of rows: {rows_num}",
        f"- Business area: {business_category}",
        f"- Columns: {columns}",
        f"- Other instruction: {extra_instruction}",
        # Plain ASCII hyphen so this bullet matches the ones above.
        "- Make sure to deliver only the markdown content without any additional comments",
    )

    return [
        {"role": "system", "content": "\n".join(system_lines)},
        {"role": "user", "content": "\n".join(user_lines)},
    ]
|
|
|
|
|
|
|
|
def pass_claude_msg(file_format: str, content: str) -> Tuple[str, str]:
    """Compose the system and user prompts for the format-conversion call.

    Args:
        file_format: Target format name inserted into the user prompt.
        content: Previously generated text that should be converted.

    Returns:
        A ``(system_prompt, user_prompt)`` tuple of stripped strings.
    """
    system_lines = [
        "You are a helpful assistant, converting generated outputs (done by other model)",
        "into the format of chosen type:",
        "example: csv, sql, or json format.",
        "NOTE: generate the result output that only includes the markdown content",
        "without any addtional comments!",
    ]
    user_lines = [
        f"Convert the output into the {file_format} format for the following content:",
        "----------------------------------------------------------------------",
        f"{content}",
    ]

    system_prompt = "\n".join(system_lines).strip()
    user_prompt = "\n".join(user_lines).strip()
    return system_prompt, user_prompt
|
|
|
|
|
|
|
|
def generate_output(messages, max_tokens=400, temperature=0.2):
    """Request a non-streaming chat completion from the open-source model.

    Args:
        messages: Chat messages (dicts with ``role`` and ``content``).
        max_tokens: Upper bound on generated tokens; the default matches
            the previously hard-coded value.
        temperature: Sampling temperature; the default matches the
            previously hard-coded value.

    Returns:
        The text of the first choice, or an empty string when the server
        returned no text content for it.
    """
    resp = open_source_client.chat.completions.create(
        model=QWEN_MODEL,
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature,
        stream=False,
    )

    # `content` is optional in the chat-completions schema; normalise a
    # missing value so callers always receive a string.
    return resp.choices[0].message.content or ""
|
|
|
|
|
|
|
|
def launch_claude_api(sys_msg, user_msg):
    """Send one system/user prompt pair to Claude and return its reply text.

    Args:
        sys_msg: System prompt for the request.
        user_msg: Content of the single user message.

    Returns:
        The ``text`` of the first content block of the response, or None
        when ``claude_client`` is falsy.
    """
    if claude_client:
        request = {
            "model": CLAUDE_MODEL,
            "system": sys_msg,
            "max_tokens": 400,
            "temperature": 0.1,
            "messages": [{"role": "user", "content": user_msg}],
        }
        reply = claude_client.messages.create(**request)
        first_block = reply.content[0]
        return first_block.text

    return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_mockup_data(category, num_data_rows, columns, a_instruction,
                         progress=gr.Progress()):
    """Generate mock data text for the values entered in the form.

    Args:
        category: Business area/category text.
        num_data_rows: Requested row count; a falsy value falls back to 10.
        columns: Column names as entered by the user.
        a_instruction: Additional free-form instruction.
        progress: Gradio progress tracker, called at 0.2 and 1.0.

    Returns:
        The text returned by ``generate_output`` for the built messages.
    """
    progress(0.2, desc="Generating...")

    row_count = int(num_data_rows or 10)
    prompt_messages = invoke_messages(
        rows_num=row_count,
        business_category=category,
        columns=columns,
        instruction=a_instruction,
    )
    generated_text = generate_output(prompt_messages)

    progress(1.0, desc="Done")
    return generated_text
|
|
|
|
|
|
|
|
def show_hidden_row():
    """Return a Gradio update that sets ``visible=True`` on its output."""
    reveal = gr.update(visible=True)
    return reveal
|
|
|
|
|
|
|
|
def _safe_filename_part(text, fallback="data", max_length=50):
    """Reduce arbitrary user text to characters that are safe in a file name."""
    cleaned = "".join(
        ch if ch.isalnum() or ch in "-_" else "_"
        for ch in (text or "").strip()
    ).strip("_")
    return cleaned[:max_length] or fallback


def make_file(btn_sort: str, category: str, content: str):
    """Convert generated text with Claude and write it to a temp file.

    Args:
        btn_sort: Target file type chosen via the download buttons
            (csv, sql or json); also used as the file extension.
        category: Business category the data belongs to; used, after
            sanitising, as the file-name prefix.
        content: LLM-generated text to convert and write.

    Returns:
        Path of the written file inside the system temp directory.

    Raises:
        gr.Error: If ``content`` is blank, no Claude client is available,
            or formatting / writing the file fails.
    """
    import tempfile

    if not content or not content.strip():
        raise gr.Error("The result content is empty. Cannot create a file.")

    if not claude_client:
        raise gr.Error("File formatting requires ANTHROPIC_API_KEY.")

    try:
        sys_msg, user_msg = pass_claude_msg(btn_sort, content)
        claude_output = launch_claude_api(sys_msg, user_msg)
        if not claude_output:
            # Fail explicitly instead of writing an empty file or tripping
            # over f.write(None).
            raise ValueError("The formatter returned no text.")

        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        # `category` is free-form user input: strip path separators and
        # other unsafe characters so the file cannot leave the temp dir.
        prefix = _safe_filename_part(category)
        filepath = os.path.join(
            tempfile.gettempdir(), f"{prefix}_mockup_{ts}.{btn_sort}"
        )

        with open(filepath, "w", encoding="utf-8") as f:
            f.write(claude_output)

        return filepath
    except Exception as e:
        # Keep the user-facing message generic but preserve the cause for logs.
        raise gr.Error("Failed to format or create the file.") from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def render_interface():
    """Build the Gradio Blocks UI and wire up its event handlers.

    The form collects a category, row count, columns and an extra
    instruction. "Generate" runs ``generate_mockup_data`` into the result
    box and then reveals a row of three download buttons, each of which
    calls ``make_file`` for its own format.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """

    def _download_handler(file_format):
        # One two-argument callback per format, bound to that format.
        return lambda category, data: make_file(file_format, category, data)

    with gr.Blocks(title="Mockup Data Generator", css="footer {visibility:hidden}") as demo:
        category = gr.Textbox(
            label="Business Area/Category",
            placeholder="e.g. HR, Sales, Hospitality, Senior Care, E-commerce, Finance",
        )
        num_data_rows = gr.Number(
            label="Number of Rows",
            placeholder="Type number...",
            minimum=10,
            maximum=50,
            step=10,
            precision=0,
        )
        columns = gr.Textbox(label="Insert Columns", placeholder="Comma, separated...")
        a_instruction = gr.Textbox(
            label="Additional Instruction",
            placeholder="Any additional instruction. Leave blank if none.",
            lines=5,
        )
        btn = gr.Button(value="Generate")
        out = gr.Textbox(label="Result shown here.")

        # Download buttons start hidden and are revealed after generation.
        with gr.Row(visible=False) as buttons_row:
            btn_csv = gr.DownloadButton(label="Download csv", size="md", elem_classes=["download-btn"])
            btn_sql = gr.DownloadButton(label="Download sql", size="md", elem_classes=["download-btn"])
            btn_json = gr.DownloadButton(label="Download json", size="md", elem_classes=["download-btn"])

        # Generate first, then make the download row visible.
        btn.click(
            fn=generate_mockup_data,
            inputs=[category, num_data_rows, columns, a_instruction],
            outputs=out,
            queue=True,
        ).then(
            fn=show_hidden_row,
            inputs=None,
            outputs=buttons_row,
        )

        # Each button passes the category and current result text to
        # make_file and receives the resulting path as its own output.
        format_buttons = (("csv", btn_csv), ("sql", btn_sql), ("json", btn_json))
        for file_format, button in format_buttons:
            button.click(
                _download_handler(file_format),
                inputs=[category, out],
                outputs=button,
            )

    return demo
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Wait (up to 900 s) for the local health endpoint before building the UI;
    # raises RuntimeError if it never becomes ready.
    wait_for_vllm_ready(900)
    app = render_interface()
    # Enable the event queue with a default concurrency limit of 1.
    app.queue(default_concurrency_limit=1)
    # Serve on all interfaces, port 7860.
    app.launch(server_name="0.0.0.0", server_port=7860)