| import json
|
| from unsloth import FastLanguageModel
|
| from transformers import TextStreamer
|
|
|
|
|
# ---- Model configuration -------------------------------------------------
# Maximum context window (in tokens) used for tokenization and generation.
max_seq_length = 4096

# None defers dtype selection to the loader (presumably auto-detects the
# best precision for the current GPU — confirm against unsloth docs).
dtype = None

# False: load full-precision weights, not the 4-bit quantized variant.
load_in_4bit = False


# Load the pretrained tabular-data model and its tokenizer once at import
# time, so every request served by this module reuses the same instances.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="betterdataai/large-tabular-model",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)

# Put the model into inference mode (unsloth's optimized generation path).
FastLanguageModel.for_inference(model)
|
|
|
def _chat_generate(messages):
    """Run one chat-formatted generation and return only the new text.

    Centralizes the tokenize -> generate -> decode plumbing used by the
    generation functions in this module, so sampling settings live in one
    place.

    Args:
        messages: Chat transcript as a list of {"role", "content"} dicts.

    Returns:
        The model's decoded continuation, with special tokens stripped.
    """
    # Apply the model's chat template and move the token ids to the GPU.
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda")

    output_ids = model.generate(
        input_ids=inputs,
        max_new_tokens=4096,
        use_cache=True,
        temperature=1.5,
        min_p=0.1,
    )

    # Slice off the prompt tokens so only the continuation is decoded.
    generated_ids = output_ids[0][inputs.shape[1]:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True)


def prompt_transformation(prompt):
    """Rewrite a natural-language query into the model's formal prompt format.

    The system prompt instructs the model to produce an objective summary,
    column descriptions, and four CSV example rows; the user turn triggers
    the transformation.

    Args:
        prompt: Free-form natural language description of the desired table.

    Returns:
        The formalized dataset-generation prompt produced by the model.
    """
    initial_prompt = """

We have the following natural language query:

"{}"



Transform the above natural language query into a formalized prompt format. The format should include:



1. A sentence summarizing the objective.

2. A description of the columns, including their data types and examples.

3. Four example rows of the dataset in CSV format.



An example of this format is as follows, please only focus on the format, not the content:



"You are tasked with generating a synthetic dataset based on the following description. The dataset represents employee information. The dataset should include the following columns:



- NAME (String): Employee's full name, consisting of a first and last name (e.g., "John Doe", "Maria Lee", "Wei Zhang").

- GENDER (String): Employee's gender (e.g., "Male", "Female").

- EMAIL (String): Employee's email address, following the standard format.

- CITY (String): City where the employee resides (e.g., "New York", "London", "Beijing").

- COUNTRY (String): Country where the employee resides (e.g., "USA", "UK", "China").

- SALARY (Float): Employee's annual salary, a value between 30000 and 150000 (e.g., 55000.0, 75000.0).



Here are some examples:

NAME,GENDER,EMAIL,CITY,COUNTRY,SALARY

John Doe,Male,john.doe@example.com,New York,USA,56000.0

Maria Lee,Female,maria.lee@nus.edu.sg,London,UK,72000.0

Wei Zhang,Male,wei.zhang@meta.com,Beijing,China,65000.0

Sara Smith,Female,sara.smith@orange.fr,Paris,France,85000.0"



Here is the transformed query from the given natural language query:

"""

    messages = [
        {"role": "system", "content": initial_prompt.format(prompt)},
        {"role": "user", "content": "transform the given natural language text to the designated format"}
    ]
    return _chat_generate(messages)
|
|
|
def table_generation(prompt):
    """Generate 20 rows of synthetic tabular data from a formalized prompt.

    Args:
        prompt: A dataset description in the model's formal prompt format
            (objective, column descriptions, CSV example rows), used as the
            system message.

    Returns:
        The model's decoded output — the generated data rows — with special
        tokens removed.
    """
    chat = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": "create 20 data rows"},
    ]

    # Render the conversation through the chat template and place the
    # resulting token ids on the GPU for generation.
    token_ids = tokenizer.apply_chat_template(
        chat,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda")

    generation = model.generate(
        input_ids=token_ids,
        max_new_tokens=4096,
        use_cache=True,
        temperature=1.5,
        min_p=0.1,
    )

    # Decode only the tokens produced after the prompt.
    new_tokens = generation[0][token_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
|
|
|
def predict(input_data):
    """Inference endpoint entry point.

    Args:
        input_data: Either a JSON string or an already-parsed dict with a
            "query" key holding the natural language dataset description.

    Returns:
        A JSON string: {"result": <generated table>} on success, or
        {"error": <message>} if the payload cannot be interpreted.
    """
    try:
        # Accept both a raw JSON string and a pre-parsed mapping.
        data = json.loads(input_data) if isinstance(input_data, str) else input_data
        user_query = data.get("query", "")
    # ValueError covers json.JSONDecodeError (malformed JSON); AttributeError
    # and TypeError cover payloads that are not mappings (e.g. a JSON list,
    # a number, or None), where .get does not exist. Narrower than the
    # original bare `except Exception` so genuine bugs are not swallowed.
    except (ValueError, AttributeError, TypeError):
        return json.dumps({
            "error": "Invalid input format. Please provide a JSON payload with a 'query' field."
        })

    # Two-stage pipeline: first normalize the free-form query into the
    # model's formal prompt format, then generate the table from it.
    transformed_prompt = prompt_transformation(user_query)
    generated_table = table_generation(transformed_prompt)

    return json.dumps({"result": generated_table})
|
|
|