Update app.py
Browse files
app.py
CHANGED
|
@@ -27,13 +27,17 @@ if not hf_token:
|
|
| 27 |
raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
|
| 28 |
|
| 29 |
# Load the GPT-2 tokenizer and language-model weights from the 'gpt2' checkpoint.
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')
model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')

# Create a pipeline for text generation using GPT-2
text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
# Define prompt template
|
| 38 |
prompt_template = """\
|
| 39 |
You are an expert in generating synthetic data for machine learning models.
|
|
@@ -62,12 +66,14 @@ Columns:
|
|
| 62 |
{columns}
|
| 63 |
Output: """
|
| 64 |
|
| 65 |
-
# Mixtral setup: read the Hugging Face token from the HF_TOKEN environment
# variable, store it via HfFolder, and load the Mixtral tokenizer with it.
token = os.getenv("HF_TOKEN")
HfFolder.save_token(token)

tokenizer_mixtral = AutoTokenizer.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    token=token,
)
|
| 70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"
|
| 72 |
|
| 73 |
generation_params = {
|
|
@@ -78,15 +84,6 @@ generation_params = {
|
|
| 78 |
"use_cache": False
|
| 79 |
}
|
| 80 |
|
| 81 |
-
def preprocess_user_prompt(user_prompt):
    """Run a user prompt through the GPT-2 text-generation pipeline.

    Args:
        user_prompt: Free-text description supplied by the caller.

    Returns:
        The ``"generated_text"`` field of the single sequence returned by
        the module-level ``text_generator`` pipeline.
    """
    # Budget the continuation with max_new_tokens rather than max_length:
    # max_length also counts the prompt's own tokens, so a prompt of 50 or
    # more tokens would leave no room to generate anything.
    outputs = text_generator(user_prompt, max_new_tokens=50, num_return_sequences=1)
    return outputs[0]["generated_text"]
|
| 84 |
-
|
| 85 |
-
def format_prompt(description, columns):
    """Build the model prompt from a description and a list of column names.

    The description is first passed through ``preprocess_user_prompt``; the
    result and the comma-joined ``columns`` are then substituted into the
    module-level ``prompt_template``.

    Args:
        description: Free-text description of the desired dataset.
        columns: Iterable of column-name strings.

    Returns:
        The filled-in prompt string.
    """
    expanded_description = preprocess_user_prompt(description)
    return prompt_template.format(
        description=expanded_description,
        columns=",".join(columns),
    )
|
| 89 |
-
|
| 90 |
def generate_synthetic_data(description, columns):
|
| 91 |
formatted_prompt = format_prompt(description, columns)
|
| 92 |
payload = {"inputs": formatted_prompt, "parameters": generation_params}
|
|
@@ -95,12 +92,18 @@ def generate_synthetic_data(description, columns):
|
|
| 95 |
|
| 96 |
def process_generated_data(csv_data, expected_columns):
|
| 97 |
try:
|
|
|
|
| 98 |
cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
|
| 99 |
data = StringIO(cleaned_data)
|
|
|
|
|
|
|
| 100 |
df = pd.read_csv(data, delimiter=',')
|
|
|
|
|
|
|
| 101 |
if set(df.columns) != set(expected_columns):
|
| 102 |
print(f"Unexpected columns in the generated data: {df.columns}")
|
| 103 |
return None
|
|
|
|
| 104 |
return df
|
| 105 |
except pd.errors.ParserError as e:
|
| 106 |
print(f"Failed to parse CSV data: {e}")
|
|
@@ -108,19 +111,24 @@ def process_generated_data(csv_data, expected_columns):
|
|
| 108 |
|
| 109 |
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    """Generate synthetic data in several batches and combine the results.

    Each batch calls ``generate_synthetic_data`` and validates the output
    with ``process_generated_data``; batches that come back as ``None`` or
    empty are skipped.

    Args:
        description: Free-text description of the desired dataset.
        columns: Expected column names of the generated data.
        num_rows: Target total number of rows.
        rows_per_generation: Nominal rows per batch. It is only used to
            compute how many batches to request; it is not passed on to
            ``generate_synthetic_data``.

    Returns:
        A DataFrame concatenating every valid batch, or an empty DataFrame
        with ``columns`` as its header when no batch was valid.
    """
    # Ceiling division so a trailing partial batch still gets its own request.
    # Floor division dropped it, which meant num_rows < rows_per_generation
    # produced no requests at all.
    num_batches = -(-num_rows // rows_per_generation)

    data_frames = []
    for _ in tqdm(range(num_batches), desc="Generating Data"):
        generated_data = generate_synthetic_data(description, columns)
        df_synthetic = process_generated_data(generated_data, columns)
        if df_synthetic is None or df_synthetic.empty:
            print("Skipping invalid generation.")
            continue
        data_frames.append(df_synthetic)

    if not data_frames:
        print("No valid data frames to concatenate.")
        return pd.DataFrame(columns=columns)
    return pd.concat(data_frames, ignore_index=True)
|
| 123 |
|
|
|
|
|
|
|
| 124 |
@app.route('/generate', methods=['POST'])
|
| 125 |
def generate():
|
| 126 |
data = request.json
|
|
|
|
| 27 |
raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
|
| 28 |
|
| 29 |
# Load the GPT-2 tokenizer and language-model weights from the 'gpt2' checkpoint.
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')
model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')

# Wrap the model and tokenizer in a text-generation pipeline.
text_generator = pipeline(
    task="text-generation",
    model=model_gpt2,
    tokenizer=tokenizer_gpt2,
)
|
| 35 |
|
| 36 |
+
def preprocess_user_prompt(user_prompt):
    """Run a user prompt through the GPT-2 text-generation pipeline.

    Args:
        user_prompt: Free-text description supplied by the caller.

    Returns:
        The ``"generated_text"`` field of the single sequence returned by
        the module-level ``text_generator`` pipeline.
    """
    # Budget the continuation with max_new_tokens rather than max_length:
    # max_length also counts the prompt's own tokens, so a prompt of 50 or
    # more tokens would leave no room to generate anything.
    outputs = text_generator(user_prompt, max_new_tokens=50, num_return_sequences=1)
    return outputs[0]["generated_text"]
|
| 40 |
+
|
| 41 |
# Define prompt template
|
| 42 |
prompt_template = """\
|
| 43 |
You are an expert in generating synthetic data for machine learning models.
|
|
|
|
| 66 |
{columns}
|
| 67 |
Output: """
|
| 68 |
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
# Authenticate with hf_token: no module-level `token` assignment is visible
# before this statement any more, so `token=token` would raise NameError at
# import time.
# NOTE(review): hf_token is assumed to be the value read from HF_API_TOKEN near
# the top of the file - confirm before merging.
tokenizer_mixtral = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
|
| 71 |
|
| 72 |
+
def format_prompt(description, columns):
    """Build the model prompt from a description and a list of column names.

    The description is first passed through ``preprocess_user_prompt``; the
    result and the comma-joined ``columns`` are then substituted into the
    module-level ``prompt_template``.

    Args:
        description: Free-text description of the desired dataset.
        columns: Iterable of column-name strings.

    Returns:
        The filled-in prompt string.
    """
    expanded_description = preprocess_user_prompt(description)
    return prompt_template.format(
        description=expanded_description,
        columns=",".join(columns),
    )
|
| 76 |
+
|
| 77 |
API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"
|
| 78 |
|
| 79 |
generation_params = {
|
|
|
|
| 84 |
"use_cache": False
|
| 85 |
}
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
def generate_synthetic_data(description, columns):
|
| 88 |
formatted_prompt = format_prompt(description, columns)
|
| 89 |
payload = {"inputs": formatted_prompt, "parameters": generation_params}
|
|
|
|
| 92 |
|
| 93 |
def process_generated_data(csv_data, expected_columns):
|
| 94 |
try:
|
| 95 |
+
# Ensure the data is cleaned and correctly formatted
|
| 96 |
cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
|
| 97 |
data = StringIO(cleaned_data)
|
| 98 |
+
|
| 99 |
+
# Read the CSV data
|
| 100 |
df = pd.read_csv(data, delimiter=',')
|
| 101 |
+
|
| 102 |
+
# Check if the DataFrame has the expected columns
|
| 103 |
if set(df.columns) != set(expected_columns):
|
| 104 |
print(f"Unexpected columns in the generated data: {df.columns}")
|
| 105 |
return None
|
| 106 |
+
|
| 107 |
return df
|
| 108 |
except pd.errors.ParserError as e:
|
| 109 |
print(f"Failed to parse CSV data: {e}")
|
|
|
|
| 111 |
|
| 112 |
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    """Generate synthetic data in several batches and combine the results.

    Each batch calls ``generate_synthetic_data`` and validates the output
    with ``process_generated_data``; batches that come back as ``None`` or
    empty are skipped.

    Args:
        description: Free-text description of the desired dataset.
        columns: Expected column names of the generated data.
        num_rows: Target total number of rows.
        rows_per_generation: Nominal rows per batch. It is only used to
            compute how many batches to request; it is not passed on to
            ``generate_synthetic_data``.

    Returns:
        A DataFrame concatenating every valid batch, or an empty DataFrame
        with ``columns`` as its header when no batch was valid.
    """
    # Ceiling division so a trailing partial batch still gets its own request.
    # Floor division dropped it, which meant num_rows < rows_per_generation
    # produced no requests at all.
    num_batches = -(-num_rows // rows_per_generation)

    data_frames = []
    for _ in tqdm(range(num_batches), desc="Generating Data"):
        generated_data = generate_synthetic_data(description, columns)
        df_synthetic = process_generated_data(generated_data, columns)
        if df_synthetic is None or df_synthetic.empty:
            print("Skipping invalid generation.")
            continue
        data_frames.append(df_synthetic)

    if not data_frames:
        print("No valid data frames to concatenate.")
        return pd.DataFrame(columns=columns)
    return pd.concat(data_frames, ignore_index=True)
|
| 129 |
|
| 130 |
+
|
| 131 |
+
|
| 132 |
@app.route('/generate', methods=['POST'])
|
| 133 |
def generate():
|
| 134 |
data = request.json
|