Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -34,11 +34,13 @@ def generate_sample_data(user_query: str, groq_api_key: str) -> dict:
|
|
| 34 |
|
| 35 |
# Get current date for context
|
| 36 |
today = datetime.now().strftime('%Y-%m-%d')
|
|
|
|
|
|
|
| 37 |
|
| 38 |
# Request to generate table schema and sample data
|
| 39 |
schema_prompt = f"""Based on this query: "{user_query}"
|
| 40 |
|
| 41 |
-
Current date: {today}
|
| 42 |
|
| 43 |
Generate a realistic database schema with sample data. Return ONLY valid JSON with this structure:
|
| 44 |
{{
|
|
@@ -51,26 +53,43 @@ Generate a realistic database schema with sample data. Return ONLY valid JSON wi
|
|
| 51 |
],
|
| 52 |
"sample_data": [
|
| 53 |
{{"column_name": value, ...}},
|
| 54 |
-
...at least
|
| 55 |
]
|
| 56 |
}}
|
| 57 |
]
|
| 58 |
}}
|
| 59 |
|
| 60 |
-
|
| 61 |
-
1. For DATE columns: Use dates in format 'YYYY-MM-DD'. Include dates from the last 60 days to ensure some fall within "last 30 days"
|
| 62 |
-
2. For queries mentioning "last X days": Generate at least 50% of dates within that timeframe
|
| 63 |
-
3. For queries with amount/price filters (e.g., "over $500"): Ensure at least 40% of records meet the criteria
|
| 64 |
-
4. For queries with thresholds: Create data both above AND below the threshold
|
| 65 |
-
5. Use realistic names, emails, and values
|
| 66 |
-
6. Make sure there's enough variety in the data to produce meaningful query results
|
| 67 |
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
response = client.chat.completions.create(
|
| 71 |
model="moonshotai/kimi-k2-instruct-0905",
|
| 72 |
messages=[
|
| 73 |
-
{"role": "system", "content": "You are a database expert. Generate realistic table schemas and sample data
|
| 74 |
{"role": "user", "content": schema_prompt}
|
| 75 |
],
|
| 76 |
temperature=0.7
|
|
@@ -84,7 +103,7 @@ Example date range: from {(datetime.now() - timedelta(days=60)).strftime('%Y-%m-
|
|
| 84 |
|
| 85 |
schema_data = json.loads(content)
|
| 86 |
|
| 87 |
-
# Post-process: Enhance data to ensure query results
|
| 88 |
schema_data = enhance_sample_data(schema_data, user_query)
|
| 89 |
|
| 90 |
return schema_data
|
|
@@ -92,9 +111,11 @@ Example date range: from {(datetime.now() - timedelta(days=60)).strftime('%Y-%m-
|
|
| 92 |
raise Exception(f"Error generating sample data: {str(e)}")
|
| 93 |
|
| 94 |
def enhance_sample_data(schema_data: dict, user_query: str) -> dict:
|
| 95 |
-
"""Enhance sample data to ensure queries return results"""
|
|
|
|
|
|
|
| 96 |
|
| 97 |
-
# Detect if query mentions time period
|
| 98 |
time_keywords = {
|
| 99 |
'last 30 days': 30,
|
| 100 |
'last 60 days': 60,
|
|
@@ -107,51 +128,122 @@ def enhance_sample_data(schema_data: dict, user_query: str) -> dict:
|
|
| 107 |
|
| 108 |
days_back = None
|
| 109 |
for keyword, days in time_keywords.items():
|
| 110 |
-
if keyword in
|
| 111 |
days_back = days
|
| 112 |
break
|
| 113 |
|
| 114 |
# Detect amount/value thresholds
|
| 115 |
-
|
| 116 |
-
amount_match = re.search(r'over \$?(\d+)',
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
for table in schema_data['tables']:
|
| 120 |
enhanced_data = []
|
| 121 |
original_data = table['sample_data']
|
| 122 |
|
| 123 |
-
#
|
| 124 |
date_cols = [col['name'] for col in table['columns'] if col['type'] == 'DATE']
|
| 125 |
-
amount_cols = [col['name'] for col in table['columns']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
-
for row in original_data:
|
| 128 |
new_row = row.copy()
|
| 129 |
|
| 130 |
-
#
|
| 131 |
-
if
|
| 132 |
-
for date_col in
|
| 133 |
if date_col in new_row:
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
new_date = (datetime.now() - timedelta(days=random_days)).strftime('%Y-%m-%d')
|
| 137 |
new_row[date_col] = new_date
|
| 138 |
|
| 139 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
if threshold_amount and amount_cols:
|
| 141 |
for amount_col in amount_cols:
|
| 142 |
if amount_col in new_row:
|
| 143 |
-
#
|
| 144 |
-
if
|
| 145 |
-
|
|
|
|
| 146 |
else:
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
enhanced_data.append(new_row)
|
| 150 |
|
| 151 |
-
#
|
| 152 |
-
while len(enhanced_data) <
|
| 153 |
-
|
| 154 |
-
template_row = enhanced_data[
|
| 155 |
|
| 156 |
# Modify IDs to be unique
|
| 157 |
for col in table['columns']:
|
|
@@ -339,20 +431,28 @@ custom_css = """
|
|
| 339 |
width: 100%;
|
| 340 |
border-collapse: collapse;
|
| 341 |
margin: 10px 0;
|
|
|
|
| 342 |
}
|
| 343 |
.table th {
|
| 344 |
-
background-color: #
|
|
|
|
| 345 |
font-weight: bold;
|
| 346 |
-
padding:
|
| 347 |
text-align: left;
|
| 348 |
-
border: 1px solid #
|
| 349 |
}
|
| 350 |
.table td {
|
| 351 |
-
padding: 8px;
|
| 352 |
-
border: 1px solid #
|
| 353 |
}
|
| 354 |
.table-striped tbody tr:nth-child(odd) {
|
| 355 |
-
background-color: #
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
}
|
| 357 |
"""
|
| 358 |
|
|
|
|
| 34 |
|
| 35 |
# Get current date for context
|
| 36 |
today = datetime.now().strftime('%Y-%m-%d')
|
| 37 |
+
past_date_2y = (datetime.now() - timedelta(days=730)).strftime('%Y-%m-%d') # 2 years ago
|
| 38 |
+
past_date_60d = (datetime.now() - timedelta(days=60)).strftime('%Y-%m-%d') # 60 days ago
|
| 39 |
|
| 40 |
# Request to generate table schema and sample data
|
| 41 |
schema_prompt = f"""Based on this query: "{user_query}"
|
| 42 |
|
| 43 |
+
**Current date: {today}**
|
| 44 |
|
| 45 |
Generate a realistic database schema with sample data. Return ONLY valid JSON with this structure:
|
| 46 |
{{
|
|
|
|
| 53 |
],
|
| 54 |
"sample_data": [
|
| 55 |
{{"column_name": value, ...}},
|
| 56 |
+
...at least 20-25 rows
|
| 57 |
]
|
| 58 |
}}
|
| 59 |
]
|
| 60 |
}}
|
| 61 |
|
| 62 |
+
**CRITICAL INSTRUCTIONS FOR REALISTIC DATA:**
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
+
1. **DATES MUST BE IN THE PAST!**
|
| 65 |
+
- For hire_date, created_at, registration_date: Use dates between {past_date_2y} and {today}
|
| 66 |
+
- For order_date, transaction_date: If query mentions "last X days", use dates between {past_date_60d} and {today}
|
| 67 |
+
- NEVER use future dates!
|
| 68 |
+
|
| 69 |
+
2. **For NUMERIC filters (salary, amount, price):**
|
| 70 |
+
- If query says "over $80000", make 50-60% of records have values ABOVE 80000
|
| 71 |
+
- Create realistic variation: some at 85k, some at 95k, some at 120k, etc.
|
| 72 |
+
- Also include records BELOW the threshold (40-50%) for realism
|
| 73 |
+
|
| 74 |
+
3. **For TEXT filters (department, category, status):**
|
| 75 |
+
- If query mentions "Engineering department", ensure 50-60% of records have department = "Engineering"
|
| 76 |
+
- Include other departments too: "Marketing", "Sales", "HR", "Finance" for variety
|
| 77 |
+
|
| 78 |
+
4. **Data quality:**
|
| 79 |
+
- Use realistic names, emails (first.last@company.com format)
|
| 80 |
+
- Make data diverse and meaningful
|
| 81 |
+
- Ensure enough records match the query criteria to get meaningful results
|
| 82 |
+
|
| 83 |
+
Example: For "Find Engineering employees with salary > 80000"
|
| 84 |
+
- Create 20+ employee records
|
| 85 |
+
- 12-15 should be in Engineering (60%)
|
| 86 |
+
- Of Engineering employees, 8-10 should have salary > 80000
|
| 87 |
+
- Include other departments with various salaries for realism"""
|
| 88 |
|
| 89 |
response = client.chat.completions.create(
|
| 90 |
model="moonshotai/kimi-k2-instruct-0905",
|
| 91 |
messages=[
|
| 92 |
+
{"role": "system", "content": "You are a database expert. Generate realistic table schemas and sample data. ALL DATES MUST BE IN THE PAST, NEVER IN THE FUTURE. Return ONLY valid JSON, no markdown formatting."},
|
| 93 |
{"role": "user", "content": schema_prompt}
|
| 94 |
],
|
| 95 |
temperature=0.7
|
|
|
|
| 103 |
|
| 104 |
schema_data = json.loads(content)
|
| 105 |
|
| 106 |
+
# Post-process: Enhance and fix data to ensure query results
|
| 107 |
schema_data = enhance_sample_data(schema_data, user_query)
|
| 108 |
|
| 109 |
return schema_data
|
|
|
|
| 111 |
raise Exception(f"Error generating sample data: {str(e)}")
|
| 112 |
|
| 113 |
def enhance_sample_data(schema_data: dict, user_query: str) -> dict:
|
| 114 |
+
"""Enhance sample data to ensure queries return results and fix any date issues"""
|
| 115 |
+
|
| 116 |
+
query_lower = user_query.lower()
|
| 117 |
|
| 118 |
+
# Detect if query mentions time period (for order/transaction dates)
|
| 119 |
time_keywords = {
|
| 120 |
'last 30 days': 30,
|
| 121 |
'last 60 days': 60,
|
|
|
|
| 128 |
|
| 129 |
days_back = None
|
| 130 |
for keyword, days in time_keywords.items():
|
| 131 |
+
if keyword in query_lower:
|
| 132 |
days_back = days
|
| 133 |
break
|
| 134 |
|
| 135 |
# Detect amount/value thresholds
|
| 136 |
+
threshold_amount = None
|
| 137 |
+
amount_match = re.search(r'(?:over|above|greater than) \$?(\d+)', query_lower)
|
| 138 |
+
if amount_match:
|
| 139 |
+
threshold_amount = int(amount_match.group(1))
|
| 140 |
+
|
| 141 |
+
# Detect text filters (department, category, status, etc.)
|
| 142 |
+
text_filters = {}
|
| 143 |
+
|
| 144 |
+
# Department detection
|
| 145 |
+
dept_patterns = [
|
| 146 |
+
r'(?:in|from) (?:the )?(\w+) department',
|
| 147 |
+
r'department (?:is |= |== )?["\']?(\w+)["\']?',
|
| 148 |
+
r'(\w+) department',
|
| 149 |
+
]
|
| 150 |
+
for pattern in dept_patterns:
|
| 151 |
+
dept_match = re.search(pattern, query_lower)
|
| 152 |
+
if dept_match:
|
| 153 |
+
text_filters['department'] = dept_match.group(1).capitalize()
|
| 154 |
+
break
|
| 155 |
+
|
| 156 |
+
# Category detection
|
| 157 |
+
category_match = re.search(r'category (?:is |= )?["\']?(\w+)["\']?', query_lower)
|
| 158 |
+
if category_match:
|
| 159 |
+
text_filters['category'] = category_match.group(1).capitalize()
|
| 160 |
+
|
| 161 |
+
# Status detection
|
| 162 |
+
status_match = re.search(r'status (?:is |= )?["\']?(\w+)["\']?', query_lower)
|
| 163 |
+
if status_match:
|
| 164 |
+
text_filters['status'] = status_match.group(1).capitalize()
|
| 165 |
|
| 166 |
for table in schema_data['tables']:
|
| 167 |
enhanced_data = []
|
| 168 |
original_data = table['sample_data']
|
| 169 |
|
| 170 |
+
# Identify column types
|
| 171 |
date_cols = [col['name'] for col in table['columns'] if col['type'] == 'DATE']
|
| 172 |
+
amount_cols = [col['name'] for col in table['columns']
|
| 173 |
+
if any(keyword in col['name'].lower() for keyword in ['amount', 'price', 'salary', 'total', 'cost', 'revenue'])]
|
| 174 |
+
|
| 175 |
+
# Identify order/transaction date columns vs hire/created date columns
|
| 176 |
+
transaction_date_cols = [col for col in date_cols
|
| 177 |
+
if any(keyword in col.lower() for keyword in ['order', 'transaction', 'purchase', 'sale', 'payment'])]
|
| 178 |
+
other_date_cols = [col for col in date_cols if col not in transaction_date_cols]
|
| 179 |
|
| 180 |
+
for i, row in enumerate(original_data):
|
| 181 |
new_row = row.copy()
|
| 182 |
|
| 183 |
+
# FIX: Ensure transaction/order dates are in the past and within time period if specified
|
| 184 |
+
if transaction_date_cols:
|
| 185 |
+
for date_col in transaction_date_cols:
|
| 186 |
if date_col in new_row:
|
| 187 |
+
if days_back:
|
| 188 |
+
# Within specified period
|
| 189 |
+
random_days = random.randint(0, days_back)
|
| 190 |
+
else:
|
| 191 |
+
# Within last 60 days for transaction-type dates
|
| 192 |
+
random_days = random.randint(0, 60)
|
| 193 |
new_date = (datetime.now() - timedelta(days=random_days)).strftime('%Y-%m-%d')
|
| 194 |
new_row[date_col] = new_date
|
| 195 |
|
| 196 |
+
# FIX: Ensure other dates (hire_date, created_at, etc.) are in the PAST
|
| 197 |
+
if other_date_cols:
|
| 198 |
+
for date_col in other_date_cols:
|
| 199 |
+
if date_col in new_row:
|
| 200 |
+
try:
|
| 201 |
+
# Check if date is in the future
|
| 202 |
+
current_date = datetime.strptime(new_row[date_col], '%Y-%m-%d')
|
| 203 |
+
if current_date > datetime.now():
|
| 204 |
+
# Replace with a past date (random between 1 month to 3 years ago)
|
| 205 |
+
random_days = random.randint(30, 1095)
|
| 206 |
+
new_date = (datetime.now() - timedelta(days=random_days)).strftime('%Y-%m-%d')
|
| 207 |
+
new_row[date_col] = new_date
|
| 208 |
+
except:
|
| 209 |
+
# If date parsing fails, generate a new past date
|
| 210 |
+
random_days = random.randint(30, 1095)
|
| 211 |
+
new_date = (datetime.now() - timedelta(days=random_days)).strftime('%Y-%m-%d')
|
| 212 |
+
new_row[date_col] = new_date
|
| 213 |
+
|
| 214 |
+
# Enhance amount fields to match threshold
|
| 215 |
if threshold_amount and amount_cols:
|
| 216 |
for amount_col in amount_cols:
|
| 217 |
if amount_col in new_row:
|
| 218 |
+
# 55% of records above threshold, 45% below
|
| 219 |
+
if i % 100 < 55: # More deterministic distribution
|
| 220 |
+
# Above threshold
|
| 221 |
+
new_row[amount_col] = int(random.uniform(threshold_amount * 1.05, threshold_amount * 2.5))
|
| 222 |
else:
|
| 223 |
+
# Below threshold
|
| 224 |
+
new_row[amount_col] = int(random.uniform(threshold_amount * 0.4, threshold_amount * 0.95))
|
| 225 |
+
|
| 226 |
+
# Apply text filters to ensure enough matching records
|
| 227 |
+
for col_name, target_value in text_filters.items():
|
| 228 |
+
if col_name in new_row:
|
| 229 |
+
# 55% should match the filter value
|
| 230 |
+
if i % 100 < 55:
|
| 231 |
+
new_row[col_name] = target_value
|
| 232 |
+
else:
|
| 233 |
+
# Use other values for variety
|
| 234 |
+
if col_name == 'department':
|
| 235 |
+
other_depts = ['Marketing', 'Sales', 'HR', 'Finance', 'Operations', 'IT']
|
| 236 |
+
new_row[col_name] = random.choice([d for d in other_depts if d != target_value])
|
| 237 |
+
elif col_name == 'status':
|
| 238 |
+
other_statuses = ['Active', 'Inactive', 'Pending', 'Completed', 'Cancelled']
|
| 239 |
+
new_row[col_name] = random.choice([s for s in other_statuses if s != target_value])
|
| 240 |
|
| 241 |
enhanced_data.append(new_row)
|
| 242 |
|
| 243 |
+
# Ensure we have at least 20 rows
|
| 244 |
+
while len(enhanced_data) < 20:
|
| 245 |
+
template_idx = len(enhanced_data) % len(original_data)
|
| 246 |
+
template_row = enhanced_data[template_idx].copy()
|
| 247 |
|
| 248 |
# Modify IDs to be unique
|
| 249 |
for col in table['columns']:
|
|
|
|
| 431 |
width: 100%;
|
| 432 |
border-collapse: collapse;
|
| 433 |
margin: 10px 0;
|
| 434 |
+
font-size: 14px;
|
| 435 |
}
|
| 436 |
.table th {
|
| 437 |
+
background-color: #4a5568;
|
| 438 |
+
color: white;
|
| 439 |
font-weight: bold;
|
| 440 |
+
padding: 10px;
|
| 441 |
text-align: left;
|
| 442 |
+
border: 1px solid #2d3748;
|
| 443 |
}
|
| 444 |
.table td {
|
| 445 |
+
padding: 8px 10px;
|
| 446 |
+
border: 1px solid #e2e8f0;
|
| 447 |
}
|
| 448 |
.table-striped tbody tr:nth-child(odd) {
|
| 449 |
+
background-color: #f7fafc;
|
| 450 |
+
}
|
| 451 |
+
.table-striped tbody tr:nth-child(even) {
|
| 452 |
+
background-color: #ffffff;
|
| 453 |
+
}
|
| 454 |
+
.table-striped tbody tr:hover {
|
| 455 |
+
background-color: #edf2f7;
|
| 456 |
}
|
| 457 |
"""
|
| 458 |
|