Spaces:
Running
Running
#!/usr/bin/env python3
"""
Initialize the SQLite database and import trade records from a CSV file.
"""
import csv
import sqlite3
from datetime import datetime  # NOTE(review): not used by the code shown here
from pathlib import Path

# SQLite database file that this script (re)creates.
DB_PATH = "data/customs_data.db"
# Source CSV whose rows are loaded into the standard_trade_records table.
CSV_PATH = "data/standard_trade_records_sample.csv"
def create_database(db_path=None):
    """Create (or recreate) the SQLite database and the trade-records schema.

    Any existing ``standard_trade_records`` table is dropped first, so all
    previously stored rows are lost.

    Args:
        db_path: Path of the SQLite database file. Defaults to the
            module-level ``DB_PATH`` (looked up at call time). ``":memory:"``
            is accepted for an in-memory database.

    Returns:
        An open ``sqlite3.Connection``. The caller is responsible for closing it.
    """
    if db_path is None:
        db_path = DB_PATH
    print(f"创建数据库: {db_path}")

    # Create the database file's own parent directory (including any
    # intermediate directories) rather than a hard-coded "data" directory,
    # so a database path outside ./data, or nested deeper, also works.
    # An in-memory database has no directory to create.
    if str(db_path) != ":memory:":
        Path(db_path).parent.mkdir(parents=True, exist_ok=True)

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Drop the old table (if any) so the schema is always rebuilt from scratch.
        cursor.execute("DROP TABLE IF EXISTS standard_trade_records")

        # Standard trade-record table.
        cursor.execute("""
            CREATE TABLE standard_trade_records (
                record_id TEXT PRIMARY KEY,
                source_record_id TEXT NOT NULL,
                batch_no TEXT NOT NULL,
                source_country TEXT NOT NULL,
                trade_direction TEXT NOT NULL,
                trade_date TEXT NOT NULL,
                importer_name TEXT,
                exporter_name TEXT,
                hs_code TEXT,
                product_name TEXT,
                amount REAL,
                currency TEXT,
                weight REAL,
                weight_unit TEXT,
                origin_country TEXT,
                destination_country TEXT,
                departure_port TEXT,
                arrival_port TEXT,
                transport_mode TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP
            )
        """)

        # Secondary indexes on the columns the queries filter on.
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_source_country ON standard_trade_records(source_country)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_trade_date ON standard_trade_records(trade_date)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_hs_code ON standard_trade_records(hs_code)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_trade_direction ON standard_trade_records(trade_direction)")

        conn.commit()
    except Exception:
        # Do not leak the connection when schema creation fails.
        conn.close()
        raise
    return conn
# Placeholder trade date used when a CSV row has no trade_date value.
_DEFAULT_TRADE_DATE = "2020-01-01"

# Explicit column list, so the insert does not depend on the table's physical
# column order. created_at falls back to CURRENT_TIMESTAMP when the CSV cell
# is blank or missing. A column DEFAULT only applies when the column is
# omitted from the INSERT, not when NULL is supplied explicitly.
_INSERT_TRADE_RECORD_SQL = """
    INSERT INTO standard_trade_records (
        record_id, source_record_id, batch_no, source_country,
        trade_direction, trade_date, importer_name, exporter_name,
        hs_code, product_name, amount, currency, weight, weight_unit,
        origin_country, destination_country, departure_port,
        arrival_port, transport_mode, created_at, updated_at
    ) VALUES (
        ?, ?, ?, ?,
        ?, ?, ?, ?,
        ?, ?, ?, ?, ?, ?,
        ?, ?, ?,
        ?, ?, COALESCE(?, CURRENT_TIMESTAMP), ?
    )
"""


def _parse_float(value):
    """Convert a CSV cell to float; blank, whitespace-only or missing cells give None."""
    if value is None or not value.strip():
        return None
    return float(value)


def _normalize_trade_date(value):
    """Return the date part of a trade_date cell, or the default date when it is blank."""
    raw = (value or "").strip()
    if not raw:
        return _DEFAULT_TRADE_DATE
    # Drop a trailing time-of-day part, e.g. "2023-01-05 10:00:00" -> "2023-01-05".
    return raw.split()[0]


def _row_to_record(row):
    """Map one csv.DictReader row to the parameter tuple for _INSERT_TRADE_RECORD_SQL."""
    return (
        row["record_id"],
        row["source_record_id"],
        row["batch_no"],
        row["source_country"],
        row["trade_direction"],
        _normalize_trade_date(row["trade_date"]),
        row["importer_name"] or None,
        row["exporter_name"] or None,
        row["hs_code"] or None,
        row["product_name"] or None,
        _parse_float(row["amount"]),
        row["currency"] or None,
        _parse_float(row["weight"]),
        row["weight_unit"] or None,
        row["origin_country"] or None,
        row["destination_country"] or None,
        row["departure_port"] or None,
        row["arrival_port"] or None,
        row["transport_mode"] or None,
        # Optional audit columns: blank or absent becomes NULL (created_at is
        # then filled by COALESCE in the SQL).
        row.get("created_at") or None,
        row.get("updated_at") or None,
    )


def import_csv_data(conn, csv_path=None, batch_size=1000):
    """Replace the contents of standard_trade_records with the rows of a CSV file.

    The old rows are deleted and the CSV rows inserted inside a single
    transaction. A failure part-way through (unparsable number, constraint
    violation, missing column, ...) therefore rolls back to the previous
    contents instead of leaving a partially imported table.

    Args:
        conn: Open sqlite3 connection whose database already contains the
            standard_trade_records table.
        csv_path: CSV file to import. Defaults to the module-level CSV_PATH
            (looked up at call time).
        batch_size: Number of rows passed to each executemany call; must be >= 1.

    Returns:
        None. If the CSV file does not exist, a warning is printed and the
        table is left untouched.

    Raises:
        ValueError: If batch_size < 1 or a numeric cell cannot be parsed.
        KeyError: If a required column is missing from the CSV header.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if csv_path is None:
        csv_path = CSV_PATH
    print(f"导入CSV数据: {csv_path}")

    if not Path(csv_path).exists():
        print(f"警告: CSV文件不存在 {csv_path}")
        return

    cursor = conn.cursor()
    count = 0
    # The connection context manager commits on success and rolls back on error.
    with conn:
        cursor.execute("DELETE FROM standard_trade_records")
        print("清空旧数据")

        # utf-8-sig transparently drops a leading BOM, which would otherwise be
        # glued onto the first header name. newline="" is what the csv module
        # requires so that quoted fields containing newlines are parsed correctly.
        with open(csv_path, "r", encoding="utf-8-sig", newline="") as f:
            batch = []
            for row in csv.DictReader(f):
                batch.append(_row_to_record(row))
                count += 1
                if len(batch) >= batch_size:
                    cursor.executemany(_INSERT_TRADE_RECORD_SQL, batch)
                    print(f"已导入 {count} 条记录...")
                    batch = []
            # Flush the final, partially filled batch.
            if batch:
                cursor.executemany(_INSERT_TRADE_RECORD_SQL, batch)

    print(f"总共导入 {count} 条记录")
def main():
    """Script entry point: rebuild the database, import the CSV and print a row-count summary."""
    print("=== 初始化SQLite数据库 ===\n")

    conn = create_database()
    try:
        import_csv_data(conn)

        # Verify the import by counting the stored rows.
        cursor = conn.cursor()
        cursor.execute("SELECT COUNT(*) FROM standard_trade_records")
        total = cursor.fetchone()[0]
        print(f"\n数据库中共有 {total} 条记录")

        # Row count for source_country 'BR' (Brazil).
        cursor.execute("SELECT COUNT(*) FROM standard_trade_records WHERE source_country='BR'")
        br_count = cursor.fetchone()[0]
        print(f"巴西数据: {br_count} 条")
    finally:
        # Always release the connection, even if the import or verification raises.
        conn.close()

    print("\n✅ 数据库初始化完成!")


if __name__ == "__main__":
    main()