aledraa committed on
Commit
7a21e38
·
verified ·
1 Parent(s): ff37ecd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +253 -0
app.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+ import gradio as gr
3
+ import json
4
+ import torch
5
+ import re
6
+ import random
7
+
8
class TableDataGenerator:
    """Generate rows of synthetic table data with a chat-tuned causal LM.

    The model and tokenizer named by ``model_name`` are loaded in
    ``__init__``. Rows are requested from the model in batches and recovered
    from its free-text answer by ``_parse_response``.
    """

    # Matches a nested list such as [['a', 'b'], ['c', 'd']]; DOTALL lets the
    # list span several lines of model output.
    _LIST_PATTERN = re.compile(r'\[\s*\[.*?\]\s*\]', re.DOTALL)
    # Matches a single flat row such as ['a', 'b'] (no nested brackets).
    _ROW_PATTERN = re.compile(r'\[([^\[\]]+)\]')

    def __init__(self, model_name="Qwen/Qwen2.5-3B-Instruct"):
        """Load the causal LM and its tokenizer.

        Args:
            model_name: Model identifier passed to ``from_pretrained``.
        """
        self.model_name = model_name
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def generate_batch_data(self, llm_commands, num_rows=1000, batch_size=50):
        """Generate table data in batches for better performance.

        Args:
            llm_commands: One textual description per column.
            num_rows: Maximum number of rows to return.
            batch_size: Number of rows requested from the model per prompt.

        Returns:
            A list of rows (each a list of strings), at most ``num_rows``
            long. It may be shorter when the model's answers cannot be parsed
            into rows of the expected width.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        if num_rows <= 0:
            return []

        all_rows = []

        # Create column headers description
        columns_desc = ", ".join(
            f"Column {i+1}: {cmd}" for i, cmd in enumerate(llm_commands)
        )

        # Ceiling division: number of prompts needed to cover num_rows.
        num_batches = (num_rows + batch_size - 1) // batch_size

        for batch_idx in range(num_batches):
            current_batch_size = min(batch_size, num_rows - len(all_rows))

            # Create prompt for this batch
            prompt = f"""Generate {current_batch_size} rows of realistic data for a table with these columns:
{columns_desc}

Requirements:
- Each row should be different and realistic
- Return ONLY a Python list format like: [['value1', 'value2'], ['value3', 'value4'], ...]
- Make the data diverse and realistic
- Use seed value {batch_idx + 1} for variety
- No explanations, just the list

Generate {current_batch_size} rows:"""

            messages = [
                {"role": "system", "content": "You are a data generator. Return only valid Python list format with realistic, diverse data."},
                {"role": "user", "content": prompt},
            ]

            response = self._generate_response(messages)

            # Only rows with exactly one value per column are kept.
            all_rows.extend(self._parse_response(response, len(llm_commands)))

            # The model may return more rows than asked for; stop early.
            if len(all_rows) >= num_rows:
                break

        return all_rows[:num_rows]

    def _generate_response(self, messages):
        """Generate response from the model.

        Args:
            messages: Chat messages (dicts with ``role`` and ``content``).

        Returns:
            The decoded text of the newly generated tokens only.
        """
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)

        # Re-seed torch's global RNG so repeated prompts sample differently.
        torch.manual_seed(random.randint(1, 10000))

        generated_ids = self.model.generate(
            **model_inputs,
            max_new_tokens=512,
            temperature=0.8,
            do_sample=True,
            top_p=0.9,
        )

        # Drop the prompt tokens that precede each generated continuation.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        return self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    def _parse_response(self, response, expected_columns):
        """Parse the model response to extract table rows.

        The response is untrusted model output, so it is only ever parsed as
        Python *literals* (``ast.literal_eval``) and never executed.

        Args:
            response: Raw text returned by the model.
            expected_columns: Required number of values per row.

        Returns:
            A list of rows, each a list of ``expected_columns`` strings.
            Rows of any other width are discarded.
        """
        import ast

        literal_errors = (ValueError, SyntaxError, TypeError, MemoryError, RecursionError)

        if not isinstance(response, str):
            return []

        rows = []

        # First choice: a complete nested list; use the longest candidate.
        matches = self._LIST_PATTERN.findall(response)
        if matches:
            try:
                parsed_data = ast.literal_eval(max(matches, key=len))
            except literal_errors:
                parsed_data = None
            if isinstance(parsed_data, list):
                rows = [
                    [str(item) for item in row]
                    for row in parsed_data
                    if isinstance(row, list) and len(row) == expected_columns
                ]

        if rows:
            return rows

        # Fallback: recover individual flat rows, e.g. from truncated output.
        for match in self._ROW_PATTERN.findall(response):
            try:
                # A literal parse keeps commas inside quoted values intact.
                parsed_row = ast.literal_eval(f"[{match}]")
                items = [str(item) for item in parsed_row]
            except literal_errors:
                # Unquoted values: split by comma and clean up.
                items = [item.strip().strip('"\'') for item in match.split(',')]
            if len(items) == expected_columns:
                rows.append(items)

        return rows
133
+
134
_TABLE_GENERATOR = None  # lazily created TableDataGenerator shared by all calls


def _get_table_generator():
    """Return the shared TableDataGenerator, loading the model on first use."""
    global _TABLE_GENERATOR
    if _TABLE_GENERATOR is None:
        _TABLE_GENERATOR = TableDataGenerator()
    return _TABLE_GENERATOR


def generate_table_data(json_input, num_rows=1000):
    """Main function to generate table data from JSON input.

    Args:
        json_input: JSON text of the form ``{"llm_commands": [...]}`` with
            one column description per list entry.
        num_rows: Number of rows to request from the generator.

    Returns:
        A ``(result_text, rows)`` tuple on every path. On success
        ``result_text`` is a summary with a preview of the first 10 rows and
        ``rows`` is the full list of generated rows. On failure
        ``result_text`` starts with ``"Error:"`` and ``rows`` is ``[]``.
    """
    try:
        data = json.loads(json_input)

        if not isinstance(data, dict):
            return "Error: JSON input must be an object containing 'llm_commands'", []

        llm_commands = data.get('llm_commands', [])

        if not llm_commands:
            # Return a 2-tuple like every other path so callers can unpack it.
            return "Error: No llm_commands found in JSON input", []

        if not isinstance(llm_commands, list):
            return "Error: 'llm_commands' must be a list of column descriptions", []

        # Reuse one generator: building it loads the whole model.
        generator = _get_table_generator()
        rows = generator.generate_batch_data(llm_commands, num_rows)

        parts = [
            f"Generated {len(rows)} rows:\n",
            f"Columns: {llm_commands}\n\n",
            "First 10 rows:\n",
        ]
        # Show first 10 rows as preview
        parts.extend(f"{i+1}: {row}\n" for i, row in enumerate(rows[:10]))

        if len(rows) > 10:
            parts.append(f"\n... and {len(rows) - 10} more rows")

        return "".join(parts), rows

    except json.JSONDecodeError:
        return "Error: Invalid JSON format", []
    except Exception as e:
        # Top-level boundary for the UI: report the failure as text.
        return f"Error: {str(e)}", []
168
+
169
+ # Gradio Interface
170
def process_json_input(json_input, num_rows):
    """Process JSON input and return formatted results.

    Args:
        json_input: JSON text forwarded to ``generate_table_data``.
        num_rows: Requested row count; converted with ``int()``.

    Returns:
        A ``(result_text, csv_content)`` tuple. ``csv_content`` is the
        generated rows as CSV text without a trailing newline, or ``""``
        when no rows were produced.
    """
    import csv
    import io

    result_text, rows = generate_table_data(json_input, int(num_rows))

    if not rows:
        return result_text, ""

    # csv.writer quotes values containing commas, quotes or newlines, which a
    # plain ",".join() would turn into corrupt rows.
    buffer = io.StringIO()
    csv.writer(buffer, lineterminator="\n").writerows(rows)
    csv_content = buffer.getvalue()
    if csv_content.endswith("\n"):
        # Keep the previous contract: rows are newline-separated, not terminated.
        csv_content = csv_content[:-1]
    return result_text, csv_content
180
+
181
+ # Create Gradio interface
182
with gr.Blocks(title="Table Data Generator") as demo:
    gr.Markdown("# Table Data Generator using LLM")
    gr.Markdown("Generate realistic table data based on column descriptions")

    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            json_input = gr.Textbox(
                label="JSON Input",
                placeholder='{"llm_commands": ["ages between 1 to 20", "arabic name"]}',
                lines=3,
                value='{"llm_commands": ["ages between 1 to 20", "arabic name"]}'
            )
            num_rows = gr.Slider(
                minimum=10,
                maximum=2000,
                value=100,
                step=10,
                label="Number of rows to generate"
            )
            generate_btn = gr.Button("Generate Data", variant="primary")

        # Right column: preview text and the downloadable CSV.
        with gr.Column():
            output_text = gr.Textbox(
                label="Generated Data Preview",
                lines=15,
                max_lines=20
            )
            download_csv = gr.File(
                label="Download CSV",
                visible=True
            )

    def generate_and_save(json_input, num_rows):
        """Generate the data and write it to a temporary CSV file.

        Returns:
            ``(result_text, path)`` where ``path`` is the temporary CSV file,
            or ``None`` when there is no CSV content to save.
        """
        import tempfile

        result_text, csv_content = process_json_input(json_input, num_rows)

        if not csv_content:
            return result_text, None

        # Explicit UTF-8 so non-ASCII values (e.g. Arabic names) are written
        # correctly whatever the platform default encoding is; newline=''
        # stops the OS from rewriting the CSV line endings. delete=False
        # keeps the file on disk after the handle is closed.
        with tempfile.NamedTemporaryFile(
            mode='w', suffix='.csv', delete=False, encoding='utf-8', newline=''
        ) as f:
            f.write(csv_content)
            temp_path = f.name

        return result_text, temp_path

    generate_btn.click(
        fn=generate_and_save,
        inputs=[json_input, num_rows],
        outputs=[output_text, download_csv]
    )

    # Example inputs
    gr.Examples(
        examples=[
            ['{"llm_commands": ["ages between 1 to 20", "arabic name"]}', 50],
            ['{"llm_commands": ["random city", "population number", "country"]}', 100],
            ['{"llm_commands": ["product name", "price in USD", "category"]}', 75],
            ['{"llm_commands": ["email address", "phone number", "job title"]}', 60]
        ],
        inputs=[json_input, num_rows]
    )

# Launch the UI only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()
249
+
250
+ # Example usage:
251
+ # json_input = '{"llm_commands": ["ages between 1 to 20", "arabic name"]}'
252
+ # result_text, rows = generate_table_data(json_input, 1000)
253
+ # print(result_text)