|
|
import json
|
|
|
import os
|
|
|
import torch
|
|
|
from transformers import TextStreamer
|
|
|
from unsloth import FastLanguageModel
|
|
|
from peft import PeftModel
|
|
|
|
|
|
|
|
|
# --- Model loading configuration ---

# Maximum context window (tokens) for loading and generation.
max_seq_length = 4096

# None lets the loader auto-select a dtype for the current hardware.
dtype = None

# Load full-precision weights; no 4-bit quantization.
load_in_4bit = False


# Base instruct model that the LoRA adapter below was trained against.
base_model_id = "unsloth/Llama-3.2-3B-Instruct"


# Load the base model and its tokenizer via Unsloth.
base_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=base_model_id,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)


# Attach the "large-tabular-model" LoRA adapter on top of the base model
# and put the combined model in eval mode.
model = PeftModel.from_pretrained(
    base_model,
    "betterdataai/large-tabular-model",
    torch_dtype=torch.float16,
).eval()


# Enable Unsloth's optimized inference path for generation.
FastLanguageModel.for_inference(model)
|
|
|
|
|
|
def prompt_transformation(prompt: str) -> str:
    """Rewrite a natural-language query into the formalized prompt format.

    The instruction template below (sent as the system message) shows the
    model the target format -- an objective sentence, column descriptions,
    and four example CSV rows -- and the user message asks it to transform
    ``prompt`` accordingly.

    Args:
        prompt: Free-form natural-language description of the desired dataset.

    Returns:
        The decoded model output: the formalized prompt text.
    """
    # System-message template; ``{}`` is filled with the caller's query.
    initial_prompt = """
We have the following natural language query:
"{}"

Transform the above natural language query into a formalized prompt format. The format should include:

1. A sentence summarizing the objective.
2. A description of the columns, including their data types and examples.
3. Four example rows of the dataset in CSV format.

An example of this format is as follows, please only focus on the format, not the content:

"You are tasked with generating a synthetic dataset based on the following description. The dataset represents employee information. The dataset should include the following columns:

- NAME (String): Employee's full name, consisting of a first and last name (e.g., "John Doe", "Maria Lee", "Wei Zhang").
- GENDER (String): Employee's gender (e.g., "Male", "Female").
- EMAIL (String): Employee's email address, following the standard format.
- CITY (String): City where the employee resides (e.g., "New York", "London", "Beijing").
- COUNTRY (String): Country where the employee resides (e.g., "USA", "UK", "China").
- SALARY (Float): Employee's annual salary, a value between 30000 and 150000 (e.g., 55000.0, 75000.0).

Here are some examples:
NAME,GENDER,EMAIL,CITY,COUNTRY,SALARY
John Doe,Male,john.doe@example.com,New York,USA,56000.0
Maria Lee,Female,maria.lee@nus.edu.sg,London,UK,72000.0
Wei Zhang,Male,wei.zhang@meta.com,Beijing,China,65000.0
Sara Smith,Female,sara.smith@orange.fr,Paris,France,85000.0"

Here is the transformed query from the given natural language query:
"""
    messages = [
        {"role": "system", "content": initial_prompt.format(prompt)},
        {"role": "user", "content": "transform the given natural language text to the designated format"}
    ]

    # Tokenize the chat with the generation prompt appended, on GPU.
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda")

    # NOTE(review): temperature=1.5 is unusually high for a format-rewriting
    # task -- presumably paired with min_p sampling on purpose; confirm.
    output_ids = model.generate(
        input_ids=inputs,
        max_new_tokens=4096,
        use_cache=True,
        temperature=1.5,
        min_p=0.1
    )

    # Slice off the echoed prompt tokens; decode only the new completion.
    generated_ids = output_ids[0][inputs.shape[1]:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True)
|
|
|
|
|
|
def table_generation(prompt):
    """Generate 20 rows of synthetic tabular data from a formalized prompt.

    The formalized dataset description becomes the system message, and a
    fixed user message requests 20 rows.

    Args:
        prompt: Formalized dataset description (system-role content).

    Returns:
        The decoded model output containing the generated rows.
    """
    chat = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": "create 20 data rows"},
    ]

    # Tokenize the conversation with the generation prompt appended.
    prompt_ids = tokenizer.apply_chat_template(
        chat,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda")

    completion = model.generate(
        input_ids=prompt_ids,
        max_new_tokens=4096,
        use_cache=True,
        temperature=1.5,
        min_p=0.1,
    )

    # Drop the echoed prompt tokens before decoding the completion.
    new_tokens = completion[0][prompt_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
|
|
|
|
|
|
def predict(input_data):
    """Inference endpoint entry point.

    Expects ``input_data`` as a JSON string or dict with a key ``"query"``
    that contains the natural language query.

    Args:
        input_data: JSON string or already-parsed dict payload.

    Returns:
        A JSON string: ``{"result": <generated table>}`` on success, or
        ``{"error": <message>}`` on invalid input.
    """
    # Parse/normalize the payload. This is the service boundary, so any
    # malformed payload yields an error JSON instead of an exception.
    try:
        data = json.loads(input_data) if isinstance(input_data, str) else input_data
        user_query = data.get("query", "")
    except Exception:
        return json.dumps({
            "error": "Invalid input format. Please provide a JSON payload with a 'query' field."
        })

    # Reject a missing/blank query up front rather than silently sending an
    # empty prompt to the model (the original generated from "" in that case).
    if not isinstance(user_query, str) or not user_query.strip():
        return json.dumps({
            "error": "Missing or empty 'query' field in the input payload."
        })

    # Step 1: rewrite the natural-language query into the formalized format.
    transformed_prompt = prompt_transformation(user_query)

    # Step 2: generate the table rows from the formalized prompt.
    generated_table = table_generation(transformed_prompt)

    return json.dumps({"result": generated_table})
|
|
|
|