#!/usr/bin/env python3
"""
Initialize the SQLite database and import trade records from a CSV file.

Running this module as a script drops and recreates the
``standard_trade_records`` table in ``DB_PATH``, loads every row of
``CSV_PATH`` into it, and prints summary counts.
"""
import csv
import sqlite3
from datetime import datetime  # noqa: F401  (not used by this module)
from pathlib import Path

DB_PATH = "data/customs_data.db"
CSV_PATH = "data/standard_trade_records_sample.csv"

# Rows are handed to executemany() in chunks of this size; a progress line is
# printed after each full chunk.
BATCH_SIZE = 1000

# trade_date is NOT NULL, so a blank date in the CSV is replaced by this value.
DEFAULT_TRADE_DATE = '2020-01-01'

CREATE_TABLE_SQL = """
    CREATE TABLE standard_trade_records (
        record_id TEXT PRIMARY KEY,
        source_record_id TEXT NOT NULL,
        batch_no TEXT NOT NULL,
        source_country TEXT NOT NULL,
        trade_direction TEXT NOT NULL,
        trade_date TEXT NOT NULL,
        importer_name TEXT,
        exporter_name TEXT,
        hs_code TEXT,
        product_name TEXT,
        amount REAL,
        currency TEXT,
        weight REAL,
        weight_unit TEXT,
        origin_country TEXT,
        destination_country TEXT,
        departure_port TEXT,
        arrival_port TEXT,
        transport_mode TEXT,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        updated_at TIMESTAMP
    )
"""

# Columns that receive a secondary index; each index is named idx_<column>.
_INDEXED_COLUMNS = ("source_country", "trade_date", "hs_code", "trade_direction")

# Columns are listed explicitly so the insert does not depend on the table's
# physical column order. A value is always supplied for created_at, so its
# column DEFAULT never fires; COALESCE reproduces that CURRENT_TIMESTAMP
# fallback for rows whose created_at is blank or missing.
INSERT_SQL = """
    INSERT INTO standard_trade_records (
        record_id, source_record_id, batch_no, source_country, trade_direction,
        trade_date, importer_name, exporter_name, hs_code, product_name,
        amount, currency, weight, weight_unit, origin_country,
        destination_country, departure_port, arrival_port, transport_mode,
        created_at, updated_at
    ) VALUES (
        ?, ?, ?, ?, ?,
        ?, ?, ?, ?, ?,
        ?, ?, ?, ?, ?,
        ?, ?, ?, ?,
        COALESCE(?, CURRENT_TIMESTAMP), ?
    )
"""


def create_database(db_path=None):
    """Create (or recreate) the SQLite database and the trade-records schema.

    Any existing ``standard_trade_records`` table is dropped first, so all
    previously stored rows are lost.

    Args:
        db_path: Path of the SQLite file. Defaults to the module-level
            ``DB_PATH``, read at call time.

    Returns:
        An open ``sqlite3.Connection``; the caller must close it.
    """
    db_path = DB_PATH if db_path is None else db_path
    print(f"创建数据库: {db_path}")

    # Derive the directory from db_path so non-default locations work too.
    Path(db_path).parent.mkdir(parents=True, exist_ok=True)

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("DROP TABLE IF EXISTS standard_trade_records")
        cursor.execute(CREATE_TABLE_SQL)
        for column in _INDEXED_COLUMNS:
            cursor.execute(
                f"CREATE INDEX IF NOT EXISTS idx_{column} "
                f"ON standard_trade_records({column})"
            )
        conn.commit()
    except Exception:
        # Don't leak the connection when the schema cannot be created.
        conn.close()
        raise
    return conn


def _text_or_none(value):
    """Return *value* unchanged, or None when the CSV cell is missing or empty."""
    return value or None


def _float_or_none(value):
    """Parse a numeric CSV cell; missing, empty or whitespace-only cells become None.

    Raises:
        ValueError: if the cell holds non-blank text that float() cannot parse.
    """
    text = (value or '').strip()
    return float(text) if text else None


def _normalize_trade_date(raw):
    """Return the date portion of *raw*, or DEFAULT_TRADE_DATE when it is blank.

    Everything after the first whitespace is treated as a time part and
    dropped (``'2023-01-05 10:00:00'`` -> ``'2023-01-05'``). An ISO-8601 ``T``
    separator directly after a 10-character date is handled the same way
    (``'2023-01-05T10:00:00'`` -> ``'2023-01-05'``).
    """
    text = (raw or '').strip()
    if not text:
        return DEFAULT_TRADE_DATE
    date_part = text.split()[0]
    if len(date_part) > 10 and date_part[10] == 'T':
        date_part = date_part[:10]
    return date_part


def _row_to_params(row):
    """Convert one csv.DictReader row into the parameter tuple for INSERT_SQL."""
    return (
        row['record_id'],
        row['source_record_id'],
        row['batch_no'],
        row['source_country'],
        row['trade_direction'],
        _normalize_trade_date(row['trade_date']),
        _text_or_none(row['importer_name']),
        _text_or_none(row['exporter_name']),
        _text_or_none(row['hs_code']),
        _text_or_none(row['product_name']),
        _float_or_none(row['amount']),
        _text_or_none(row['currency']),
        _float_or_none(row['weight']),
        _text_or_none(row['weight_unit']),
        _text_or_none(row['origin_country']),
        _text_or_none(row['destination_country']),
        _text_or_none(row['departure_port']),
        _text_or_none(row['arrival_port']),
        _text_or_none(row['transport_mode']),
        _text_or_none(row.get('created_at')),
        _text_or_none(row.get('updated_at')),
    )


def import_csv_data(conn, csv_path=None):
    """Replace the contents of standard_trade_records with the rows of a CSV file.

    If the CSV file does not exist, a warning is printed and the table is left
    untouched. Otherwise the delete and all inserts run in a single
    transaction. If any row fails, everything is rolled back and the table
    keeps its previous contents.

    Args:
        conn: Open connection to a database that already has the
            standard_trade_records table (see create_database).
        csv_path: CSV file to read. Defaults to the module-level ``CSV_PATH``,
            read at call time. It must have a header row with the column names.

    Raises:
        KeyError: if a required column is missing from the CSV header.
        ValueError: if amount or weight contains non-numeric text.
        sqlite3.IntegrityError: e.g. on a duplicate record_id.
    """
    csv_path = CSV_PATH if csv_path is None else csv_path
    print(f"导入CSV数据: {csv_path}")

    if not Path(csv_path).exists():
        print(f"警告: CSV文件不存在 {csv_path}")
        return

    count = 0
    # `with conn` commits on success and rolls back on any exception, so the
    # table is never left emptied or half-loaded.
    with conn:
        cursor = conn.cursor()
        cursor.execute("DELETE FROM standard_trade_records")
        print("清空旧数据")

        # newline='' is required by the csv module so quoted fields that
        # contain line breaks are parsed correctly. utf-8-sig strips a leading
        # BOM (e.g. from Excel exports) that would otherwise corrupt the first
        # header name, and it decodes BOM-less UTF-8 identically.
        with open(csv_path, 'r', encoding='utf-8-sig', newline='') as f:
            batch = []
            for row in csv.DictReader(f):
                batch.append(_row_to_params(row))
                count += 1
                if len(batch) >= BATCH_SIZE:
                    cursor.executemany(INSERT_SQL, batch)
                    print(f"已导入 {count} 条记录...")
                    batch = []
            # Flush the final partial chunk.
            if batch:
                cursor.executemany(INSERT_SQL, batch)

    print(f"总共导入 {count} 条记录")


def main():
    """Rebuild the database, import the CSV, and print summary counts."""
    print("=== 初始化SQLite数据库 ===\n")

    conn = create_database()
    try:
        import_csv_data(conn)

        # Sanity-check the result.
        cursor = conn.cursor()
        cursor.execute("SELECT COUNT(*) FROM standard_trade_records")
        total = cursor.fetchone()[0]
        print(f"\n数据库中共有 {total} 条记录")

        cursor.execute("SELECT COUNT(*) FROM standard_trade_records WHERE source_country='BR'")
        br_count = cursor.fetchone()[0]
        print(f"巴西数据: {br_count} 条")
    finally:
        conn.close()

    print("\n✅ 数据库初始化完成!")


if __name__ == "__main__":
    main()