// rstock/src/bin/api_server.rs
//! HTTP API wrapper for bulk_download.
//!
//! Start this server, then create download jobs with:
//! POST /jobs {"date":20260224,"concurrent":50,"force":false}
use anyhow::{Context, Result as AnyResult};
use axum::{
extract::{Path, State},
http::StatusCode,
response::{Html, IntoResponse},
routing::{get, post},
Json, Router,
};
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
env,
net::SocketAddr,
path::PathBuf,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
};
use tokio::{process::Command, sync::RwLock};
use tracing::{error, info};
/// Shared state handed to every handler via axum's `State` extractor.
/// Cloning is cheap: all heavy members are behind `Arc`.
#[derive(Clone)]
struct AppState {
    // In-memory job registry keyed by job id; not persisted, lost on restart.
    jobs: Arc<RwLock<HashMap<String, JobRecord>>>,
    // Monotonic sequence number mixed into generated job ids for uniqueness.
    counter: Arc<AtomicU64>,
    // Path to the bulk_download executable spawned per job.
    bulk_bin: PathBuf,
    // Working directory the child process is started in.
    workdir: PathBuf,
}
/// JSON body accepted by `POST /jobs`.
#[derive(Debug, Deserialize)]
struct CreateJobRequest {
    /// Trading date in YYYYMMDD format, for example 20260224.
    date: u32,
    /// Download concurrency. Defaults to 50.
    #[serde(default = "default_concurrent")]
    concurrent: usize,
    /// true = force re-download, false = resume/incremental mode.
    #[serde(default)]
    force: bool,
}
/// Serde default for `CreateJobRequest::concurrent`: 50 parallel downloads.
fn default_concurrent() -> usize {
    50
}
/// Full record of one download job, returned verbatim by `GET /jobs` and
/// `GET /jobs/:job_id`.
#[derive(Debug, Clone, Serialize)]
struct JobRecord {
    // Generated id, e.g. "job-20260224120000-1".
    id: String,
    // Requested trading date (YYYYMMDD).
    date: u32,
    // Requested concurrency.
    concurrent: usize,
    // Whether --force was passed to the child process.
    force: bool,
    status: JobStatus,
    // RFC 3339 UTC timestamps.
    started_at: String,
    finished_at: Option<String>,
    // Child process exit code; None while running or if killed by a signal.
    exit_code: Option<i32>,
    // Last ~64 KiB of child output, captured after the process exits.
    stdout_tail: Option<String>,
    stderr_tail: Option<String>,
    // Spawn-failure message when the process could not be started at all.
    error: Option<String>,
}
/// Lifecycle of a job; serialized as "running"/"succeeded"/"failed".
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
enum JobStatus {
    Running,
    Succeeded,
    Failed,
}
/// Response for `POST /jobs` (202 Accepted): the id plus a URL to poll.
#[derive(Debug, Serialize)]
struct CreateJobResponse {
    job_id: String,
    status: JobStatus,
    // Relative path for polling the job, e.g. "/jobs/job-...".
    status_url: String,
}
/// Response for `GET /health`, echoing the resolved runtime configuration.
#[derive(Debug, Serialize)]
struct HealthResponse {
    status: &'static str,
    bulk_bin: String,
    workdir: String,
}
/// Uniform JSON error envelope: {"error": "..."} with a non-2xx status.
#[derive(Debug, Serialize)]
struct ErrorResponse {
    error: String,
}
#[tokio::main]
async fn main() -> AnyResult<()> {
    // Log filter comes from RUST_LOG; default to "info" when unset.
    tracing_subscriber::fmt()
        .with_env_filter(env::var("RUST_LOG").unwrap_or_else(|_| "info".to_string()))
        .init();
    let bulk_bin = resolve_bulk_download_bin();
    let workdir = resolve_workdir()?;
    let state = AppState {
        jobs: Arc::new(RwLock::new(HashMap::new())),
        counter: Arc::new(AtomicU64::new(1)),
        bulk_bin,
        workdir,
    };
    // Route table: "/" and "/ui" both serve the embedded UI; jobs are a
    // minimal REST resource (create via POST, list/inspect via GET).
    let app = Router::new()
        .route("/", get(index))
        .route("/ui", get(index))
        .route("/health", get(health))
        .route("/jobs", post(create_job).get(list_jobs))
        .route("/jobs/:job_id", get(get_job))
        .with_state(state.clone());
    // Bind address: API_HOST (default 0.0.0.0) plus PORT, falling back to
    // API_PORT, then 7860.
    let host = env::var("API_HOST").unwrap_or_else(|_| "0.0.0.0".to_string());
    let port = env::var("PORT")
        .or_else(|_| env::var("API_PORT"))
        .unwrap_or_else(|_| "7860".to_string())
        .parse::<u16>()
        .context("PORT/API_PORT must be a valid u16")?;
    let addr: SocketAddr = format!("{}:{}", host, port)
        .parse()
        .context("invalid API_HOST/API_PORT")?;
    info!("API server listening on http://{}", addr);
    info!("bulk_download binary: {}", state.bulk_bin.display());
    info!("working directory: {}", state.workdir.display());
    // Serve until the process is killed; axum::serve only returns on error.
    let listener = tokio::net::TcpListener::bind(addr).await?;
    axum::serve(listener, app).await?;
    Ok(())
}
/// Serves the single-page UI, embedded into the binary at compile time.
async fn index() -> Html<&'static str> {
    Html(include_str!("../../static/index.html"))
}
/// `GET /health`: liveness probe that also reports the resolved
/// bulk_download binary path and working directory.
async fn health(State(state): State<AppState>) -> Json<HealthResponse> {
    let bulk_bin = state.bulk_bin.display().to_string();
    let workdir = state.workdir.display().to_string();
    Json(HealthResponse {
        status: "ok",
        bulk_bin,
        workdir,
    })
}
/// `GET /jobs`: snapshot of all known jobs, newest first.
async fn list_jobs(State(state): State<AppState>) -> Json<Vec<JobRecord>> {
    let registry = state.jobs.read().await;
    let mut snapshot: Vec<JobRecord> = registry.values().cloned().collect();
    // RFC 3339 timestamps sort lexicographically by time; reverse the
    // comparison so the most recently started job comes first.
    snapshot.sort_by(|lhs, rhs| rhs.started_at.cmp(&lhs.started_at));
    Json(snapshot)
}
/// `GET /jobs/:job_id`: one job record, or 404 if the id is unknown.
async fn get_job(
    State(state): State<AppState>,
    Path(job_id): Path<String>,
) -> Result<Json<JobRecord>, (StatusCode, Json<ErrorResponse>)> {
    state
        .jobs
        .read()
        .await
        .get(&job_id)
        .map(|job| Json(job.clone()))
        .ok_or_else(|| api_error(StatusCode::NOT_FOUND, "job not found"))
}
/// `POST /jobs`: validates the request, registers a Running record, and
/// spawns the bulk_download child process on a background task. Replies
/// 202 Accepted immediately with a URL the client can poll.
async fn create_job(
    State(state): State<AppState>,
    Json(req): Json<CreateJobRequest>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
    validate_request(&req)?;
    // Timestamp plus an atomic sequence number keeps ids unique even when
    // several jobs are created within the same second.
    let seq = state.counter.fetch_add(1, Ordering::Relaxed);
    let job_id = format!("job-{}-{}", chrono::Utc::now().format("%Y%m%d%H%M%S"), seq);
    let record = JobRecord {
        id: job_id.clone(),
        date: req.date,
        concurrent: req.concurrent,
        force: req.force,
        status: JobStatus::Running,
        started_at: chrono::Utc::now().to_rfc3339(),
        finished_at: None,
        exit_code: None,
        stdout_tail: None,
        stderr_tail: None,
        error: None,
    };
    state.jobs.write().await.insert(job_id.clone(), record);
    // Run the download outside the request so the handler returns at once.
    {
        let task_state = state.clone();
        let task_job_id = job_id.clone();
        tokio::spawn(async move {
            run_bulk_download_job(task_state, task_job_id, req).await;
        });
    }
    let status_url = format!("/jobs/{}", job_id);
    Ok((
        StatusCode::ACCEPTED,
        Json(CreateJobResponse {
            job_id,
            status: JobStatus::Running,
            status_url,
        }),
    ))
}
fn validate_request(req: &CreateJobRequest) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
let year = req.date / 10000;
let month = (req.date / 100) % 100;
let day = req.date % 100;
if !(2000..=2099).contains(&year) || !(1..=12).contains(&month) || !(1..=31).contains(&day) {
return Err(api_error(
StatusCode::BAD_REQUEST,
"date must be in YYYYMMDD format, for example 20260224",
));
}
if req.concurrent == 0 || req.concurrent > 1000 {
return Err(api_error(
StatusCode::BAD_REQUEST,
"concurrent must be between 1 and 1000",
));
}
Ok(())
}
/// Runs `bulk_bin <date> <concurrent> [--force]` in the configured working
/// directory, waits for it to finish, then updates the job's record with
/// exit status and the tail of captured stdout/stderr.
async fn run_bulk_download_job(state: AppState, job_id: String, req: CreateJobRequest) {
    let mut cmd = Command::new(&state.bulk_bin);
    cmd.arg(req.date.to_string())
        .arg(req.concurrent.to_string())
        .current_dir(&state.workdir);
    if req.force {
        cmd.arg("--force");
    }
    info!("starting job {}", job_id);
    // Await the child before taking the lock, so the registry is never
    // held across the (potentially long) download.
    let result = cmd.output().await;
    let mut registry = state.jobs.write().await;
    let Some(job) = registry.get_mut(&job_id) else {
        // Record vanished (should not happen); nothing to update.
        return;
    };
    job.finished_at = Some(chrono::Utc::now().to_rfc3339());
    match result {
        Ok(out) => {
            // code() is None when the child was killed by a signal.
            job.exit_code = out.status.code();
            // Keep only the last ~64 KiB of each stream.
            job.stdout_tail =
                Some(tail_string(&String::from_utf8_lossy(&out.stdout), 64 * 1024));
            job.stderr_tail =
                Some(tail_string(&String::from_utf8_lossy(&out.stderr), 64 * 1024));
            job.status = if out.status.success() {
                JobStatus::Succeeded
            } else {
                JobStatus::Failed
            };
            info!("job {} finished with status {:?}", job_id, job.status);
        }
        Err(err) => {
            // The process could not even be spawned (missing binary, perms).
            error!("job {} failed to start: {}", job_id, err);
            job.status = JobStatus::Failed;
            job.error = Some(err.to_string());
        }
    }
}
/// Builds the uniform (status, JSON error body) pair used by all handlers.
fn api_error(status: StatusCode, message: impl Into<String>) -> (StatusCode, Json<ErrorResponse>) {
    let body = ErrorResponse {
        error: message.into(),
    };
    (status, Json(body))
}
/// Returns `s` unchanged when it fits in `max_bytes`; otherwise returns
/// "...\n" followed by roughly the last `max_bytes` bytes, with the cut
/// point nudged forward so it never splits a UTF-8 character.
fn tail_string(s: &str, max_bytes: usize) -> String {
    let mut cut = match s.len().checked_sub(max_bytes) {
        None | Some(0) => return s.to_string(),
        Some(excess) => excess,
    };
    // Advance past any partial multi-byte sequence at the cut point.
    while !s.is_char_boundary(cut) {
        cut += 1;
    }
    format!("...\n{}", &s[cut..])
}
/// Locates the bulk_download executable.
///
/// Resolution order: the `BULK_DOWNLOAD_BIN` environment variable, then a
/// `bulk_download` binary sitting next to the current executable, and
/// finally the bare name, deferring to PATH lookup at spawn time.
fn resolve_bulk_download_bin() -> PathBuf {
    if let Ok(explicit) = env::var("BULK_DOWNLOAD_BIN") {
        return PathBuf::from(explicit);
    }
    env::current_exe()
        .ok()
        .map(|exe| exe.with_file_name("bulk_download"))
        .filter(|candidate| candidate.exists())
        .unwrap_or_else(|| PathBuf::from("bulk_download"))
}
/// Chooses the working directory for spawned jobs: `BULK_DOWNLOAD_WORKDIR`
/// when set, otherwise the server's current working directory.
fn resolve_workdir() -> AnyResult<PathBuf> {
    match env::var("BULK_DOWNLOAD_WORKDIR") {
        Ok(dir) => Ok(PathBuf::from(dir)),
        Err(_) => env::current_dir().context("cannot resolve current working directory"),
    }
}