File size: 4,491 Bytes
0228f64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import json
from unsloth import FastLanguageModel
from transformers import TextStreamer  # if needed elsewhere

# Generation/model configuration.
max_seq_length = 4096  # maximum context window (tokens) passed to the loader
dtype = None  # None lets Unsloth auto-detect a suitable dtype for the hardware
load_in_4bit = False  # full-precision weights; True would load 4-bit quantized

# Load model and tokenizer once at startup so every request reuses them.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="betterdataai/large-tabular-model", 
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
# Switch the Unsloth model into its optimized inference mode.
FastLanguageModel.for_inference(model)

def prompt_transformation(prompt):
    """Rewrite a natural-language query into the formalized dataset prompt.

    Builds a few-shot system prompt asking the model to emit (1) an objective
    sentence, (2) column descriptions with types/examples, and (3) four
    example CSV rows, then generates the transformed prompt.

    Args:
        prompt: Free-form natural-language description of the desired dataset.
            NOTE(review): it is interpolated via ``str.format``, so literal
            ``{``/``}`` in the query would raise — confirm upstream sanitizing.

    Returns:
        The model's transformed prompt as a string (special tokens stripped).
    """
    initial_prompt = """

    We have the following natural language query:

    "{}"

    

    Transform the above natural language query into a formalized prompt format. The format should include:

    

    1. A sentence summarizing the objective.

    2. A description of the columns, including their data types and examples.

    3. Four example rows of the dataset in CSV format.

    

    An example of this format is as follows, please only focus on the format, not the content:

    

    "You are tasked with generating a synthetic dataset based on the following description. The dataset represents employee information. The dataset should include the following columns:

    

    - NAME (String): Employee's full name, consisting of a first and last name (e.g., "John Doe", "Maria Lee", "Wei Zhang").

    - GENDER (String): Employee's gender (e.g., "Male", "Female").

    - EMAIL (String): Employee's email address, following the standard format.

    - CITY (String): City where the employee resides (e.g., "New York", "London", "Beijing").

    - COUNTRY (String): Country where the employee resides (e.g., "USA", "UK", "China").

    - SALARY (Float): Employee's annual salary, a value between 30000 and 150000 (e.g., 55000.0, 75000.0).

    

    Here are some examples:

    NAME,GENDER,EMAIL,CITY,COUNTRY,SALARY

    John Doe,Male,john.doe@example.com,New York,USA,56000.0

    Maria Lee,Female,maria.lee@nus.edu.sg,London,UK,72000.0

    Wei Zhang,Male,wei.zhang@meta.com,Beijing,China,65000.0

    Sara Smith,Female,sara.smith@orange.fr,Paris,France,85000.0"

    

    Here is the transformed query from the given natural language query:

    """
    
    messages = [
        {"role": "system", "content": initial_prompt.format(prompt)},
        {"role": "user", "content": "transform the given natural language text to the designated format"}
    ]
    
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,  # Required for generation
        return_tensors="pt",
    ).to(model.device)  # fix: follow the model's device instead of hard-coding "cuda"
    
    output_ids = model.generate(
        input_ids=inputs, 
        max_new_tokens=4096, 
        use_cache=True, 
        temperature=1.5, 
        min_p=0.1
    )
    
    # generate() echoes the prompt; decode only the newly generated tail.
    generated_ids = output_ids[0][inputs.shape[1]:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True)

def table_generation(prompt):
    """Generate 20 CSV data rows from a formalized dataset prompt.

    Args:
        prompt: A formalized dataset description (normally the output of
            ``prompt_transformation``), used verbatim as the system prompt.

    Returns:
        The model's generated table text as a string (special tokens stripped).
    """
    messages = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": "create 20 data rows"}
    ]
    
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,  # Required for generation
        return_tensors="pt",
    ).to(model.device)  # fix: follow the model's device instead of hard-coding "cuda"
    
    output_ids = model.generate(
        input_ids=inputs, 
        max_new_tokens=4096, 
        use_cache=True, 
        temperature=1.5, 
        min_p=0.1
    )
    
    # generate() echoes the prompt; decode only the newly generated tail.
    generated_ids = output_ids[0][inputs.shape[1]:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True)

def predict(input_data):
    """Inference endpoint entry point.

    Expects ``input_data`` as a JSON string or dict with a key ``"query"``
    containing the natural language query.

    Returns:
        A JSON string: ``{"result": <generated table>}`` on success, or
        ``{"error": <message>}`` when the payload is malformed or the query
        is empty.
    """
    try:
        if isinstance(input_data, str):
            data = json.loads(input_data)
        else:
            data = input_data
        user_query = data.get("query", "")
    # Narrowed from a blanket Exception: json.loads raises ValueError
    # (JSONDecodeError) on bad JSON; .get raises AttributeError when the
    # payload is not a mapping (e.g. a bare int or list).
    except (ValueError, AttributeError, TypeError):
        return json.dumps({
            "error": "Invalid input format. Please provide a JSON payload with a 'query' field."
        })
    
    # Reject empty/missing queries up front instead of spending two full
    # 4096-token generations on an empty prompt.
    if not user_query:
        return json.dumps({
            "error": "Empty 'query'. Please provide a non-empty natural language query."
        })
    
    # Transform the user query into the desired prompt format
    transformed_prompt = prompt_transformation(user_query)
    # Generate the table using the transformed prompt
    generated_table = table_generation(transformed_prompt)
    
    return json.dumps({"result": generated_table})