File size: 4,420 Bytes
7c3e988
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
use anyhow::{Context, Result};
use clap::Parser;
use std::net::SocketAddr;
use std::path::PathBuf;
use tokio::sync::mpsc;
use tracing::info;

mod api;
mod models;
mod storage;
mod ws_client;

use api::build_router;
use models::AppState;
use storage::StorageWriter;
use ws_client::spawn_connection;

// Command-line configuration, parsed once in `main` via `Cli::parse()`.
//
// NOTE: the `///` doc comments in this block are picked up by clap's derive macro
// and rendered as `--help` text, so editing them changes user-visible output.
// Maintainer-only notes are therefore written as plain `//` comments.
/// Sina Finance WebSocket real-time data API.
///
/// Connects to wss://hq.sinajs.cn/wskt, keeps latest stock tick data in memory,
/// exposes a REST API, and also persists raw records to daily CSV files.
#[derive(Parser)]
#[command(name = "sina-realtime-api", version)]
struct Cli {
    /// Stock list file — one code per line, e.g. sz300394 (lines starting with # are ignored)
    // `main` reads this file in full, trims each line, and drops blank lines and
    // lines beginning with `#`; an empty resulting list is a startup error.
    #[arg(short, long, default_value = "stocks_100.txt")]
    stocks: PathBuf,

    /// Directory for output CSV files
    // Created with `create_dir_all` in `main` before being handed to the storage task.
    #[arg(short, long, default_value = "data")]
    output: PathBuf,

    /// Max stocks per WebSocket connection (tune based on URL length / server limits)
    // Used in `main` as the size argument of `stocks.chunks(..)`; one connection
    // task is spawned per resulting chunk.
    #[arg(long, default_value_t = 500)]
    chunk_size: usize,

    /// Internal channel buffer (records in flight between WS tasks and storage)
    // Used in `main` as the capacity of the bounded `mpsc::channel::<String>`.
    #[arg(long, default_value_t = 131_072)]
    buffer: usize,

    /// API host. Hugging Face Spaces should use 0.0.0.0.
    // `main` formats this as "{host}:{port}" and parses the result as a
    // `SocketAddr`, so it has to be an IP literal rather than a hostname.
    #[arg(long, default_value = "0.0.0.0")]
    api_host: String,

    /// API port. If omitted, reads PORT env var, then falls back to 7860.
    // `None` means the flag was not given; the env/default fallback lives in `main`.
    #[arg(long)]
    api_port: Option<u16>,
}

#[tokio::main]
async fn main() -> Result<()> {
    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| "sina_realtime_collector=info,tower_http=warn".into()),
        )
        .init();

    let cli = Cli::parse();

    // Load and validate stock list
    let content = std::fs::read_to_string(&cli.stocks)
        .with_context(|| format!("Cannot read stock list: {:?}", cli.stocks))?;

    let stocks: Vec<String> = content
        .lines()
        .map(|l| l.trim().to_string())
        .filter(|l| !l.is_empty() && !l.starts_with('#'))
        .collect();

    anyhow::ensure!(!stocks.is_empty(), "Stock list is empty");
    info!("Loaded {} stocks from {:?}", stocks.len(), cli.stocks);

    std::fs::create_dir_all(&cli.output)
        .with_context(|| format!("Cannot create output dir: {:?}", cli.output))?;

    let started_at = chrono::Local::now()
        .format("%Y-%m-%dT%H:%M:%S%.3f")
        .to_string();
    let state = AppState::new(stocks.clone(), started_at);

    // Shared channel: WS tasks -> storage/API-state task
    let (tx, rx) = mpsc::channel::<String>(cli.buffer);

    // Storage also updates the latest-quote in-memory map used by the API.
    let storage_state = state.clone();
    let output_dir = cli.output.clone();
    let storage_handle = tokio::spawn(async move {
        if let Err(e) = StorageWriter::run(rx, output_dir, storage_state).await {
            tracing::error!("Storage error: {e:#}");
        }
    });

    // Split stocks into chunks, one WebSocket connection per chunk.
    let chunks: Vec<Vec<String>> = stocks
        .chunks(cli.chunk_size)
        .map(|c| c.to_vec())
        .collect();

    info!(
        "Starting {} connection(s) (~{} stocks each)",
        chunks.len(),
        cli.chunk_size
    );

    let mut conn_handles = Vec::with_capacity(chunks.len());
    for (i, chunk) in chunks.into_iter().enumerate() {
        let tx = tx.clone();
        conn_handles.push(tokio::spawn(spawn_connection(chunk, tx, i)));
    }
    drop(tx);

    let port = cli
        .api_port
        .or_else(|| std::env::var("PORT").ok().and_then(|v| v.parse::<u16>().ok()))
        .unwrap_or(7860);
    let addr: SocketAddr = format!("{}:{}", cli.api_host, port)
        .parse()
        .with_context(|| format!("Invalid API bind address: {}:{}", cli.api_host, port))?;

    let app = build_router(state);
    let listener = tokio::net::TcpListener::bind(addr).await?;
    info!("API listening on http://{addr}");

    let server_result = axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await;

    for h in conn_handles {
        h.abort();
    }
    storage_handle.abort();

    server_result?;
    Ok(())
}

/// Resolves once the process is asked to stop.
///
/// Waits for Ctrl-C (SIGINT) on every platform and additionally for SIGTERM on Unix,
/// which is the signal container runtimes normally send when stopping a service.
/// If the Ctrl-C listener cannot be installed, a warning is logged and the future
/// resolves immediately (same as before). If only the SIGTERM listener fails, a
/// warning is logged and Ctrl-C remains the sole trigger.
async fn shutdown_signal() {
    let ctrl_c = async {
        if let Err(e) = tokio::signal::ctrl_c().await {
            tracing::warn!("Failed to listen for shutdown signal: {e}");
        }
    };

    #[cfg(unix)]
    let terminate = async {
        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
            Ok(mut sigterm) => {
                sigterm.recv().await;
            }
            Err(e) => {
                tracing::warn!("Failed to listen for SIGTERM: {e}");
                // Never resolve, so this branch cannot trigger a spurious shutdown.
                std::future::pending::<()>().await;
            }
        }
    };

    // No SIGTERM on non-Unix targets: this branch simply never completes.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {}
        _ = terminate => {}
    }
    info!("Shutdown signal received");
}