# tabular-llm-demo / inference.py
# Uploaded by betterdataai (commit 0228f64, verified)
import json
from unsloth import FastLanguageModel
from transformers import TextStreamer # if needed elsewhere
# Set parameters
max_seq_length = 4096   # maximum context length passed to the tokenizer/model
dtype = None            # None lets unsloth auto-detect (e.g. bf16/fp16) per hardware
load_in_4bit = False    # full-precision weights; set True for 4-bit quantized loading

# Load model and tokenizer once at module import so every request reuses them.
# NOTE(review): this downloads/loads "betterdataai/large-tabular-model" as a
# module-level side effect — importing this file requires the weights to be
# reachable and enough memory to hold them.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="betterdataai/large-tabular-model",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
# Switch unsloth's optimized inference mode on (disables training-only paths).
FastLanguageModel.for_inference(model)
def prompt_transformation(prompt):
    """Rewrite a natural-language query into the formalized tabular prompt.

    The raw user query is embedded into a fixed few-shot template (objective
    sentence, column descriptions, four CSV example rows) and the model is
    asked to emit the transformed prompt.

    Parameters
    ----------
    prompt : str
        The user's natural-language description of the desired dataset.

    Returns
    -------
    str
        The model's generated transformation, with special tokens stripped.
    """
    initial_prompt = """
We have the following natural language query:
"{}"
Transform the above natural language query into a formalized prompt format. The format should include:
1. A sentence summarizing the objective.
2. A description of the columns, including their data types and examples.
3. Four example rows of the dataset in CSV format.
An example of this format is as follows, please only focus on the format, not the content:
"You are tasked with generating a synthetic dataset based on the following description. The dataset represents employee information. The dataset should include the following columns:
- NAME (String): Employee's full name, consisting of a first and last name (e.g., "John Doe", "Maria Lee", "Wei Zhang").
- GENDER (String): Employee's gender (e.g., "Male", "Female").
- EMAIL (String): Employee's email address, following the standard format.
- CITY (String): City where the employee resides (e.g., "New York", "London", "Beijing").
- COUNTRY (String): Country where the employee resides (e.g., "USA", "UK", "China").
- SALARY (Float): Employee's annual salary, a value between 30000 and 150000 (e.g., 55000.0, 75000.0).
Here are some examples:
NAME,GENDER,EMAIL,CITY,COUNTRY,SALARY
John Doe,Male,john.doe@example.com,New York,USA,56000.0
Maria Lee,Female,maria.lee@nus.edu.sg,London,UK,72000.0
Wei Zhang,Male,wei.zhang@meta.com,Beijing,China,65000.0
Sara Smith,Female,sara.smith@orange.fr,Paris,France,85000.0"
Here is the transformed query from the given natural language query:
"""
    messages = [
        {"role": "system", "content": initial_prompt.format(prompt)},
        {"role": "user", "content": "transform the given natural language text to the designated format"}
    ]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,  # Required for generation
        return_tensors="pt",
    ).to(model.device)  # FIX: follow the model's device instead of hard-coding "cuda"
    output_ids = model.generate(
        input_ids=inputs,
        max_new_tokens=4096,
        use_cache=True,
        # FIX: without do_sample=True, HF generate() defaults to greedy decoding
        # and silently ignores temperature/min_p; sampling was clearly intended.
        do_sample=True,
        temperature=1.5,
        min_p=0.1,
    )
    # Drop the prompt tokens; decode only the newly generated continuation.
    generated_ids = output_ids[0][inputs.shape[1]:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True)
def table_generation(prompt):
    """Generate 20 CSV data rows from a formalized tabular prompt.

    Parameters
    ----------
    prompt : str
        A formalized prompt (as produced by ``prompt_transformation``) used
        as the system message.

    Returns
    -------
    str
        The model's generated table text, with special tokens stripped.
    """
    messages = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": "create 20 data rows"}
    ]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,  # Required for generation
        return_tensors="pt",
    ).to(model.device)  # FIX: follow the model's device instead of hard-coding "cuda"
    output_ids = model.generate(
        input_ids=inputs,
        max_new_tokens=4096,
        use_cache=True,
        # FIX: without do_sample=True, HF generate() defaults to greedy decoding
        # and silently ignores temperature/min_p; sampling was clearly intended.
        do_sample=True,
        temperature=1.5,
        min_p=0.1,
    )
    # Drop the prompt tokens; decode only the newly generated continuation.
    generated_ids = output_ids[0][inputs.shape[1]:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True)
def predict(input_data):
    """
    Inference endpoint entry point.

    Expects input_data as a JSON string or dict with a key "query" that
    contains the natural language query.

    Returns a JSON string: {"result": <generated table>} on success, or
    {"error": <message>} when the input cannot be parsed.
    """
    try:
        # FIX: narrowed from a bare `except Exception` to the specific
        # failures this parse/extract step can raise, so unrelated bugs
        # are not silently reported as "invalid input".
        data = json.loads(input_data) if isinstance(input_data, str) else input_data
        user_query = data.get("query", "")
    except (json.JSONDecodeError, AttributeError, TypeError, ValueError):
        return json.dumps({
            "error": "Invalid input format. Please provide a JSON payload with a 'query' field."
        })
    # Transform the user query into the desired prompt format
    transformed_prompt = prompt_transformation(user_query)
    # Generate the table using the transformed prompt
    generated_table = table_generation(transformed_prompt)
    return json.dumps({"result": generated_table})