betterdataai committed on
Commit aae3c09 · verified · 1 Parent(s): 39860ea

Upload handler.py

Files changed (1)
  1. handler.py +128 -76
handler.py CHANGED
@@ -1,86 +1,138 @@
- import os
  import json
  import torch
- from peft import PeftModel
- from transformers import (
-     LlamaForCausalLM,
-     LlamaTokenizer,
-     GenerationConfig,
- )

- class EndpointHandler:
-     def __init__(self, model_dir: str = ".", **kwargs):
-         """
-         This method runs once when the Endpoint first starts.
-         - model_dir is the local directory of *this* repository,
-           which contains your LoRA adapter weights (e.g. adapter_model.safetensors).
-         """

-         # 1) Base model from Hugging Face.
-         # Make sure to use the EXACT base you trained on, or it won't match your LoRA.
-         self.base_model_id = "unsloth/Llama-3.2-3B-Instruct"
-
-         # If your base model is gated/private, you'll need a token:
-         # hf_token = os.getenv("HF_TOKEN", None)
-
-         # 2) Load the tokenizer
-         self.tokenizer = LlamaTokenizer.from_pretrained(
-             self.base_model_id,
-             trust_remote_code=True,
-             # use_auth_token=hf_token,  # if needed
-         )
-
-         # 3) Load the base model
-         self.base_model = LlamaForCausalLM.from_pretrained(
-             self.base_model_id,
-             device_map="auto",           # or "cuda:0"
-             torch_dtype=torch.float16,   # or bfloat16
-             trust_remote_code=True,
-             # use_auth_token=hf_token,   # if needed
-         )

-         # 4) Load/merge your LoRA adapter
-         self.model = PeftModel.from_pretrained(
-             self.base_model,
-             model_dir,  # The local directory of this repo
-             torch_dtype=torch.float16,
-         ).eval()
-
-     def __call__(self, data):
-         """
-         This method is called for every request to the endpoint.
-         `data` is a dictionary (or JSON string) containing user inputs.
-         Returns a dictionary or string (will be serialized as JSON).
-         """
-         # If data is a JSON string, parse it:
-         if isinstance(data, str):
-             data = json.loads(data)
-
-         # Extract the user prompt from the request payload
-         prompt = data.get("inputs", "")
-         if not isinstance(prompt, str):
-             raise ValueError("`inputs` must be a string.")

-         # Optionally extract generation params (max_new_tokens, temperature, etc.)
-         # If none provided, use defaults:
-         gen_params = data.get("parameters", {})
-         generation_config = GenerationConfig(
-             max_new_tokens=gen_params.get("max_new_tokens", 128),
-             temperature=gen_params.get("temperature", 0.7),
-             top_p=gen_params.get("top_p", 0.9),
-             do_sample=gen_params.get("do_sample", True),
-             # etc.
-         )

-         # Tokenize the prompt
-         inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

-         # Generate text
-         with torch.no_grad():
-             output_ids = self.model.generate(**inputs, generation_config=generation_config)

-         # Decode the output
-         output_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

-         # Return the generated text in a JSON-friendly format
-         return {"generated_text": output_text}
  import json
+ import os
  import torch
+ from transformers import TextStreamer  # if needed elsewhere
+ from unsloth import FastLanguageModel  # Assumes FastLanguageModel supports loading a base model
+ from peft import PeftModel  # For loading the adapter onto the base model

+ # Set parameters
+ max_seq_length = 4096
+ dtype = None
+ load_in_4bit = False

+ # Define the base model identifier (the full model from the Hugging Face Hub)
+ base_model_id = "meta-llama/Llama-3.2-3B-Instruct"

+ # 1. Load the base model and tokenizer from the Hub.
+ #    (This downloads the complete base model with all weights.)
+ base_model, tokenizer = FastLanguageModel.from_pretrained(
+     model_name=base_model_id,
+     max_seq_length=max_seq_length,
+     dtype=dtype,
+     load_in_4bit=load_in_4bit,
+ )

+ # 2. Load your LoRA adapter weights from your repository.
+ #    Here, "betterdataai/large-tabular-model" should be the local directory or Hub identifier
+ #    where the adapter weights reside. Ensure that this path contains the adapter weights
+ #    (e.g. adapter_model.safetensors) and configuration.
+ model = PeftModel.from_pretrained(
+     base_model,
+     "betterdataai/large-tabular-model",  # Path to your adapter weights
+     torch_dtype=torch.float16,
+ ).eval()

+ # 3. Prepare the merged model for inference.
+ FastLanguageModel.for_inference(model)

+ def prompt_transformation(prompt):
+     initial_prompt = """
+ We have the following natural language query:
+ "{}"
+
+ Transform the above natural language query into a formalized prompt format. The format should include:
+
+ 1. A sentence summarizing the objective.
+ 2. A description of the columns, including their data types and examples.
+ 3. Four example rows of the dataset in CSV format.
+
+ An example of this format is as follows; please focus only on the format, not the content:
+
+ "You are tasked with generating a synthetic dataset based on the following description. The dataset represents employee information. The dataset should include the following columns:
+
+ - NAME (String): Employee's full name, consisting of a first and last name (e.g., "John Doe", "Maria Lee", "Wei Zhang").
+ - GENDER (String): Employee's gender (e.g., "Male", "Female").
+ - EMAIL (String): Employee's email address, following the standard format.
+ - CITY (String): City where the employee resides (e.g., "New York", "London", "Beijing").
+ - COUNTRY (String): Country where the employee resides (e.g., "USA", "UK", "China").
+ - SALARY (Float): Employee's annual salary, a value between 30000 and 150000 (e.g., 55000.0, 75000.0).
+
+ Here are some examples:
+ NAME,GENDER,EMAIL,CITY,COUNTRY,SALARY
+ John Doe,Male,john.doe@example.com,New York,USA,56000.0
+ Maria Lee,Female,maria.lee@nus.edu.sg,London,UK,72000.0
+ Wei Zhang,Male,wei.zhang@meta.com,Beijing,China,65000.0
+ Sara Smith,Female,sara.smith@orange.fr,Paris,France,85000.0"
+
+ Here is the transformed query from the given natural language query:
+ """
+     messages = [
+         {"role": "system", "content": initial_prompt.format(prompt)},
+         {"role": "user", "content": "transform the given natural language text to the designated format"},
+     ]
+
+     inputs = tokenizer.apply_chat_template(
+         messages,
+         tokenize=True,
+         add_generation_prompt=True,  # Required for generation
+         return_tensors="pt",
+     ).to("cuda")
+
+     output_ids = model.generate(
+         input_ids=inputs,
+         max_new_tokens=4096,
+         use_cache=True,
+         temperature=1.5,
+         min_p=0.1,
+     )
+
+     # Decode only the newly generated tokens, skipping the prompt.
+     generated_ids = output_ids[0][inputs.shape[1]:]
+     return tokenizer.decode(generated_ids, skip_special_tokens=True)

+ def table_generation(prompt):
+     messages = [
+         {"role": "system", "content": prompt},
+         {"role": "user", "content": "create 20 data rows"},
+     ]
+
+     inputs = tokenizer.apply_chat_template(
+         messages,
+         tokenize=True,
+         add_generation_prompt=True,  # Required for generation
+         return_tensors="pt",
+     ).to("cuda")
+
+     output_ids = model.generate(
+         input_ids=inputs,
+         max_new_tokens=4096,
+         use_cache=True,
+         temperature=1.5,
+         min_p=0.1,
+     )
+
+     generated_ids = output_ids[0][inputs.shape[1]:]
+     return tokenizer.decode(generated_ids, skip_special_tokens=True)

+ def predict(input_data):
+     """
+     Inference endpoint entry point.
+
+     Expects input_data as a JSON string or dict with a key "query" that contains the natural language query.
+     Returns a JSON string with the generated table.
+     """
+     try:
+         if isinstance(input_data, str):
+             data = json.loads(input_data)
+         else:
+             data = input_data
+         user_query = data.get("query", "")
+     except Exception:
+         return json.dumps({
+             "error": "Invalid input format. Please provide a JSON payload with a 'query' field."
+         })
+
+     # Transform the user query into the designated prompt format.
+     transformed_prompt = prompt_transformation(user_query)
+     # Generate the table using the transformed prompt.
+     generated_table = table_generation(transformed_prompt)
+
+     return json.dumps({"result": generated_table})
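
For reference, a minimal sketch of how the new predict() entry point can be exercised once this handler.py is available; the query text is illustrative (not from the repository), and importing the module runs the model and adapter loading at the top level, so a GPU environment with the base model downloadable is assumed:

# Minimal usage sketch for the handler above (hypothetical query text).
import json
import handler  # importing triggers base-model + adapter loading

payload = json.dumps({"query": "Generate a table of employees with name, city, and salary."})
response = handler.predict(payload)
print(json.loads(response)["result"])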