betterdataai committed on
Commit
552c5cf
·
verified ·
1 Parent(s): 04c74e1

Delete inference.py

Browse files
Files changed (1) hide show
  1. inference.py +0 -121
inference.py DELETED
@@ -1,121 +0,0 @@
1
# Inference module setup: load the tabular-generation model and tokenizer
# once at import time so repeated predict() calls reuse the same weights.
import json
from unsloth import FastLanguageModel
from transformers import TextStreamer # if needed elsewhere

# Set parameters
max_seq_length = 4096  # context length handed to the loader below
dtype = None  # presumably None lets the loader auto-select a dtype — TODO confirm
load_in_4bit = False  # no 4-bit quantization; load full-precision weights

# Load model and tokenizer once at startup
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="betterdataai/large-tabular-model",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
# Switch the loaded model into unsloth's inference mode.
FastLanguageModel.for_inference(model)
18
-
19
def prompt_transformation(prompt):
    """Rewrite a natural-language query into the formalized prompt format.

    The model is asked (via a system prompt containing a worked example) to
    produce a structured dataset description: objective sentence, column
    descriptions with types/examples, and four example CSV rows.

    Args:
        prompt: Free-form natural-language description of the desired dataset.

    Returns:
        The model's decoded text output — the formalized prompt.
    """
    initial_prompt = """
We have the following natural language query:
"{}"

Transform the above natural language query into a formalized prompt format. The format should include:

1. A sentence summarizing the objective.
2. A description of the columns, including their data types and examples.
3. Four example rows of the dataset in CSV format.

An example of this format is as follows, please only focus on the format, not the content:

"You are tasked with generating a synthetic dataset based on the following description. The dataset represents employee information. The dataset should include the following columns:

- NAME (String): Employee's full name, consisting of a first and last name (e.g., "John Doe", "Maria Lee", "Wei Zhang").
- GENDER (String): Employee's gender (e.g., "Male", "Female").
- EMAIL (String): Employee's email address, following the standard format.
- CITY (String): City where the employee resides (e.g., "New York", "London", "Beijing").
- COUNTRY (String): Country where the employee resides (e.g., "USA", "UK", "China").
- SALARY (Float): Employee's annual salary, a value between 30000 and 150000 (e.g., 55000.0, 75000.0).

Here are some examples:
NAME,GENDER,EMAIL,CITY,COUNTRY,SALARY
John Doe,Male,john.doe@example.com,New York,USA,56000.0
Maria Lee,Female,maria.lee@nus.edu.sg,London,UK,72000.0
Wei Zhang,Male,wei.zhang@meta.com,Beijing,China,65000.0
Sara Smith,Female,sara.smith@orange.fr,Paris,France,85000.0"

Here is the transformed query from the given natural language query:
"""

    messages = [
        {"role": "system", "content": initial_prompt.format(prompt)},
        {"role": "user", "content": "transform the given natural language text to the designated format"}
    ]

    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,  # Required for generation
        return_tensors="pt",
    ).to(model.device)  # was hard-coded "cuda"; follow the model's actual device

    output_ids = model.generate(
        input_ids=inputs,
        max_new_tokens=4096,
        use_cache=True,
        temperature=1.5,
        min_p=0.1
    )

    # Strip the prompt tokens so only newly generated text is decoded.
    generated_ids = output_ids[0][inputs.shape[1]:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True)
73
-
74
def table_generation(prompt):
    """Generate 20 synthetic data rows from a formalized dataset prompt.

    Args:
        prompt: A formalized dataset description, typically the output of
            prompt_transformation(), used verbatim as the system prompt.

    Returns:
        The model's decoded text output — the generated table.
    """
    messages = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": "create 20 data rows"}
    ]

    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,  # Required for generation
        return_tensors="pt",
    ).to(model.device)  # was hard-coded "cuda"; follow the model's actual device

    output_ids = model.generate(
        input_ids=inputs,
        max_new_tokens=4096,
        use_cache=True,
        temperature=1.5,
        min_p=0.1
    )

    # Strip the prompt tokens so only newly generated text is decoded.
    generated_ids = output_ids[0][inputs.shape[1]:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True)
97
-
98
def predict(input_data):
    """
    Inference endpoint entry point.

    Expects input_data as a JSON string or dict with a key "query" that contains the natural language query.
    Returns a JSON string with the generated table.
    """
    try:
        if isinstance(input_data, str):
            data = json.loads(input_data)
        else:
            data = input_data
        user_query = data.get("query", "")
    # Narrowed from a bare `except Exception`: these are the failures the
    # parse/access above can actually produce (JSONDecodeError is a
    # ValueError; non-mapping payloads raise AttributeError/TypeError).
    except (ValueError, TypeError, AttributeError):
        return json.dumps({
            "error": "Invalid input format. Please provide a JSON payload with a 'query' field."
        })

    # Transform the user query into the desired prompt format
    transformed_prompt = prompt_transformation(user_query)
    # Generate the table using the transformed prompt
    generated_table = table_generation(transformed_prompt)

    return json.dumps({"result": generated_table})