JunhanCai's picture
Upload folder using huggingface_hub
8f6b390 verified
Raw
History Blame Contribute Delete
27.2 kB
import logging
# Log at DEBUG level so the activity of the underlying libraries is visible.
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
import multiprocessing
# Force the global multiprocessing start method back to the old default, "fork".
# This is best effort:
# - ValueError means "fork" is not available on this platform (e.g. Windows);
#   get_context/set_start_method raise it for unknown methods.
# - RuntimeError was already tolerated here.
# Either way the interpreter default is kept.
try:
    multiprocessing.set_start_method('fork', force=True)
except (RuntimeError, ValueError):
    pass
import os
import multiprocessing
import signal
import shutil
import uuid
import re
from datetime import datetime
from pathlib import Path
from threading import Thread
from typing import Optional
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import json
import torch
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.responses import FileResponse, JSONResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from webserver.train_service import TrainConfig, predict_with_checkpoint, run_finetune_job
from webserver.label_utils import load_label_mapping, apply_label_mapping
# Directory layout: every path below is rooted at this module's directory.
BASE_DIR = os.path.dirname(__file__)
UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")  # per-job copies of training uploads
RUNS_DIR = os.path.join(BASE_DIR, "runs")  # per-job training output directories
PREDICTIONS_DIR = os.path.join(BASE_DIR, "predictions")  # per-request prediction inputs/outputs
TEMPLATE_DIR = os.path.join(BASE_DIR, "templates")  # Jinja2 page templates
# NOTE(review): not used by the routes visible in this module -- confirm it is still needed.
DEFAULT_MODEL_PATH = os.path.join(BASE_DIR, "weights", "Fine_tuned.pth")

# Create the writable output directories up front so request handlers can assume they exist.
for _output_dir in (UPLOAD_DIR, RUNS_DIR, PREDICTIONS_DIR):
    os.makedirs(_output_dir, exist_ok=True)
del _output_dir

app = FastAPI(title="Raman Fine-Tune Webserver")
templates = Jinja2Templates(directory=TEMPLATE_DIR)
# Job-status store shared with training subprocesses.
# - In the main process a multiprocessing.Manager dict is used. JOBS is handed
#   to the training subprocess by start_job, presumably so the child can publish
#   status updates that the HTTP handlers read back.
# - Elsewhere, or if the Manager cannot be started, a plain dict is used instead.
#   A plain dict is NOT shared across processes, so child updates stay invisible here.
JOB_MANAGER = None
JOBS = {}
try:
    if multiprocessing.current_process().name == "MainProcess":
        JOB_MANAGER = multiprocessing.Manager()
        JOBS = JOB_MANAGER.dict()
except Exception as exc:
    print(f"Warning: multiprocessing.Manager() failed: {exc}. Using local dict instead.")
    JOB_MANAGER = None
    JOBS = {}

# job_id -> multiprocessing.Process handle for jobs started by this server process.
JOB_PROCESSES = {}

# Training jobs are launched through an explicit "spawn" context, independent of
# the global start method; fall back to the module-level API if that fails.
try:
    JOB_CONTEXT = multiprocessing.get_context("spawn")
except Exception as exc:
    print(f"Warning: spawn context not available: {exc}. Using default context.")
    JOB_CONTEXT = multiprocessing
def _save_upload(file_obj: UploadFile, dst_path: str):
    """Stream the contents of an uploaded file to ``dst_path``, overwriting it.

    Args:
        file_obj: The uploaded file; its underlying ``.file`` object is copied.
        dst_path: Destination path, opened in binary write mode.
    """
    source_stream = file_obj.file
    with open(dst_path, mode="wb") as sink:
        shutil.copyfileobj(source_stream, sink)
def _load_report_text(report_path: str):
if not os.path.isfile(report_path):
return None
with open(report_path, "r", encoding="utf-8") as f:
return f.read()
def _build_artifact_entries(base_dir: str, artifact_map: dict, route_prefix: str):
entries = []
for key, filename in artifact_map.items():
file_path = os.path.join(base_dir, filename)
if not os.path.isfile(file_path):
continue
entries.append(
{
"key": key,
"filename": filename,
"url": f"/{route_prefix}/{os.path.basename(base_dir)}/{filename}",
"is_image": filename.lower().endswith((".png", ".jpg", ".jpeg", ".webp", ".gif")),
"is_text": filename.lower().endswith((".txt", ".json", ".csv")),
}
)
return entries
def _safe_result_file(root_dir: str, item_id: str, filename: str):
safe_name = os.path.basename(filename)
folder = os.path.join(root_dir, item_id)
file_path = os.path.join(folder, safe_name)
if not os.path.isfile(file_path):
raise HTTPException(status_code=404, detail="File not found")
return file_path
def _safe_uploaded_name(filename: str) -> str:
safe_name = os.path.basename(filename or "")
if not safe_name:
raise HTTPException(status_code=400, detail="Uploaded file is missing a filename")
return safe_name
def _is_optional_file(upload: Optional[UploadFile]) -> bool:
    """Return True when an optional upload field was left empty.

    Same check as ``_is_blank_upload``; kept as a separate name for existing callers.
    """
    return _is_blank_upload(upload)
def _is_blank_upload(upload: Optional[UploadFile]) -> bool:
    """Return True when ``upload`` is ``None`` or has no non-whitespace filename.

    Args:
        upload: The form field value. Any object is accepted; a missing
            ``filename`` attribute counts as blank.
    """
    filename = getattr(upload, "filename", None)
    return not filename or not str(filename).strip()
def _render_predict_results_fragment(
    prediction_id: str,
    summary: dict,
    rows: list[dict],
    top5_rows: list[dict],
    download_csv: str,
    preview_image: str,
):
    """Render ``predict_result_fragment.html`` with the prediction results and return the HTML string.

    Every argument is passed to the template under the same name.
    """
    fragment_template = templates.env.get_template("predict_result_fragment.html")
    template_context = dict(
        prediction_id=prediction_id,
        summary=summary,
        rows=rows,
        top5_rows=top5_rows,
        download_csv=download_csv,
        preview_image=preview_image,
    )
    return fragment_template.render(**template_context)
def _parse_numeric_text_file(file_path: str) -> np.ndarray:
rows = []
with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
for raw_line in f:
line = raw_line.strip()
if not line or line.startswith("#"):
continue
tokens = [token for token in re.split(r"[\s,]+", line) if token]
values = []
for token in tokens:
try:
values.append(float(token))
except ValueError:
continue
if values:
rows.append(values)
if not rows:
raise ValueError("No numeric data found in text file")
max_cols = max(len(row) for row in rows)
if max_cols == 1:
return np.asarray([row[0] for row in rows], dtype=np.float32)
return np.asarray([[row[0], row[1]] for row in rows if len(row) >= 2], dtype=np.float32)
def _load_prediction_spectrum(file_path: str) -> tuple[np.ndarray, Optional[np.ndarray], str]:
extension = os.path.splitext(file_path)[1].lower()
if extension in {".txt", ".csv"}:
data = _parse_numeric_text_file(file_path)
if data.ndim == 1:
spectra = data.astype(np.float32).reshape(1, -1)
return spectra, None, "text_intensity_only"
if data.ndim == 2 and data.shape[1] >= 2:
wavenumbers = data[:, 0].astype(np.float32)
spectra = data[:, 1].astype(np.float32).reshape(1, -1)
return spectra, wavenumbers, "text_wavenumber_intensity"
raise ValueError("Text spectrum must contain either one intensity column or two columns: wavenumber, intensity")
if extension == ".npy":
spectra = np.load(file_path, allow_pickle=True)
return np.asarray(spectra, dtype=np.float32), None, "npy"
raise ValueError("Spectrum file must be .txt, .csv, or .npy")
def _load_prediction_wavenumbers(file_path: str) -> np.ndarray:
extension = os.path.splitext(file_path)[1].lower()
if extension in {".txt", ".csv"}:
data = _parse_numeric_text_file(file_path)
if data.ndim == 1:
return data.astype(np.float32).reshape(-1)
if data.ndim == 2 and data.shape[1] >= 1:
return data[:, 0].astype(np.float32).reshape(-1)
raise ValueError("Wavelength text file must contain one numeric column")
if extension == ".npy":
return np.asarray(np.load(file_path, allow_pickle=True), dtype=np.float32).reshape(-1)
raise ValueError("Wavelength file must be .txt, .csv, or .npy")
def _build_manual_wavenumbers(length: int, low_cm: float, high_cm: float) -> np.ndarray:
if low_cm is None or high_cm is None:
raise ValueError("Manual wavelength range requires both low and high values")
if high_cm <= low_cm:
raise ValueError("Manual wavelength range high value must be greater than low value")
return np.linspace(float(low_cm), float(high_cm), int(length), dtype=np.float32)
def _save_prediction_preview(prediction_dir: str, target_wavenumbers: np.ndarray, processed_spectra: np.ndarray) -> str:
spectra = np.asarray(processed_spectra, dtype=np.float32)
wavenumbers = np.asarray(target_wavenumbers, dtype=np.float32).reshape(-1)
if spectra.ndim != 2 or spectra.shape[1] != wavenumbers.shape[0]:
raise ValueError("processed spectra and wavenumbers must have matching 2D/1D shapes")
sample_count = spectra.shape[0]
preview_count = min(sample_count, 6)
fig, ax = plt.subplots(figsize=(8, 4.5))
for idx in range(preview_count):
label = f"Sample {idx + 1}" if sample_count > 1 else "Input spectrum"
ax.plot(wavenumbers, spectra[idx], linewidth=1.0, alpha=0.9, label=label)
ax.set_title(f"Input Spectra Preview ({sample_count} sample{'s' if sample_count != 1 else ''})")
ax.set_xlabel("Wavenumber (cm$^{-1}$)")
ax.set_ylabel("Normalized intensity")
ax.set_xlim(float(wavenumbers.min()), float(wavenumbers.max()))
ax.grid(True, linestyle="--", alpha=0.3)
if preview_count > 1:
ax.legend(frameon=False, fontsize=8)
fig.tight_layout()
preview_path = os.path.join(prediction_dir, "input_spectra_preview.png")
fig.savefig(preview_path, dpi=300, bbox_inches="tight")
plt.close(fig)
return preview_path
def _reap_job_process(job_id: str, process: multiprocessing.Process):
    """Block until ``process`` exits, then drop its entry from ``JOB_PROCESSES``.

    Intended to run on a background thread so a finished child is joined
    (reaped) without blocking a request handler.
    """
    process.join(timeout=None)
    # A default for pop() tolerates the entry being absent.
    JOB_PROCESSES.pop(job_id, None)
@app.on_event("startup")
async def startup_event():
    """Print a startup banner: Python/PyTorch versions, CUDA status, and job backends."""
    import sys

    print("[STARTUP] Application initializing...")
    print(f"[STARTUP] Python version: {sys.version}")
    print(f"[STARTUP] PyTorch: {torch.__version__}")
    has_cuda = torch.cuda.is_available()
    print(f"[STARTUP] CUDA available: {has_cuda}")
    if has_cuda:
        print(f"[STARTUP] CUDA device: {torch.cuda.get_device_name(0)}")
    job_store_kind = "multiprocessing.Manager" if JOB_MANAGER else "local dict"
    print(f"[STARTUP] JOB_MANAGER: {job_store_kind}")
    print(f"[STARTUP] JOB_CONTEXT: {type(JOB_CONTEXT).__name__}")
    print("[STARTUP] Application ready!")
@app.get("/health")
def health_check():
    """Liveness probe: report that the server is up and whether CUDA is available."""
    payload = {"status": "ok"}
    payload["cuda"] = torch.cuda.is_available()
    return payload
@app.get("/")
def index(request: Request):
    """Render the ``index.html`` landing page."""
    page_context = {"request": request}
    return templates.TemplateResponse(request, "index.html", page_context)
@app.get("/predict")
def predict_page(request: Request):
    """Render the ``predict.html`` page with no results yet."""
    page_context = {"request": request}
    return templates.TemplateResponse(request, "predict.html", page_context)
@app.post("/start")
def start_job(
    request: Request,
    spectral_file: UploadFile = File(...),
    labels_file: UploadFile = File(...),
    wavenumbers_file: UploadFile = File(...),
    model_file: UploadFile = File(None),
    label_mapping_file: Optional[UploadFile] = File(None),
    epochs: int = Form(60),
    lr: float = Form(1e-4),
    weight_decay: float = Form(1e-3),
    patience: int = Form(12),
    batch_size: int = Form(64),
    patch_num: int = Form(100),
    embedding_dim: int = Form(512),
    num_layers: int = Form(12),
    num_heads: int = Form(16),
    freeze_encoder: bool = Form(False),
    label_smoothing: float = Form(0.0),
):
    """Validate and store training uploads, then launch a fine-tuning subprocess.

    Steps:
    1. Check the uploads:
       - spectra, labels and wavenumbers must be ``.npy``;
       - an optional starting model must be ``.pth``;
       - an optional label mapping must be ``.json``/``.txt``.
       Extensions are compared case-insensitively.
    2. Save the uploads under ``UPLOAD_DIR/<job_id>``.
    3. Register a ``queued`` record in ``JOBS``.
    4. Start ``run_finetune_job`` in a ``JOB_CONTEXT`` process writing to
       ``RUNS_DIR/<job_id>``.

    Returns:
        JSON with ``job_id``/``status_url``/``stop_url`` when the client asks for
        JSON (``Accept: application/json`` or ``X-Requested-With: XMLHttpRequest``).
        Otherwise a 303 redirect to the status page.

    Raises:
        HTTPException: 400 for missing or wrongly typed uploads.
    """
    if _is_optional_file(label_mapping_file):
        label_mapping_file = None
    has_model = not _is_blank_upload(model_file)

    # Validate everything before touching the filesystem.
    for upload in (spectral_file, labels_file, wavenumbers_file):
        if _is_blank_upload(upload):
            raise HTTPException(status_code=400, detail="Spectral, labels, and wavenumbers files are all required")
        if not upload.filename.lower().endswith(".npy"):
            raise HTTPException(status_code=400, detail=f"File {upload.filename} must be .npy")
    if has_model and not model_file.filename.lower().endswith(".pth"):
        raise HTTPException(status_code=400, detail="Model file must be .pth")
    label_mapping_name = None
    if label_mapping_file is not None:
        label_mapping_name = _safe_uploaded_name(label_mapping_file.filename)
        if os.path.splitext(label_mapping_name)[1].lower() not in {".json", ".txt"}:
            raise HTTPException(status_code=400, detail="Label mapping file must be .json or .txt")

    job_id = datetime.now().strftime("%Y%m%d_%H%M%S") + "_" + uuid.uuid4().hex[:8]
    job_upload_dir = os.path.join(UPLOAD_DIR, job_id)
    job_run_dir = os.path.join(RUNS_DIR, job_id)
    os.makedirs(job_upload_dir, exist_ok=True)
    os.makedirs(job_run_dir, exist_ok=True)

    # Fixed names for the required inputs. The label mapping keeps its own name;
    # its .json/.txt extension cannot clash with these.
    spectral_path = os.path.join(job_upload_dir, "spectral.npy")
    labels_path = os.path.join(job_upload_dir, "labels.npy")
    wavenumbers_path = os.path.join(job_upload_dir, "wavenumbers.npy")
    model_path = os.path.join(job_upload_dir, "model.pth") if has_model else None
    label_mapping_path = (
        os.path.join(job_upload_dir, label_mapping_name) if label_mapping_name is not None else None
    )

    _save_upload(spectral_file, spectral_path)
    _save_upload(labels_file, labels_path)
    _save_upload(wavenumbers_file, wavenumbers_path)
    if model_path is not None:
        _save_upload(model_file, model_path)
    if label_mapping_path is not None:
        _save_upload(label_mapping_file, label_mapping_path)

    config = TrainConfig(
        epochs=epochs,
        lr=lr,
        weight_decay=weight_decay,
        patience=patience,
        batch_size=batch_size,
        patch_num=patch_num,
        embedding_dim=embedding_dim,
        num_layers=num_layers,
        num_heads=num_heads,
        freeze_encoder=freeze_encoder,
        label_smoothing=label_smoothing,
    )
    input_paths = {
        "spectral": spectral_path,
        "labels": labels_path,
        "wavenumbers": wavenumbers_path,
        "model": model_path,
        "label_mapping": label_mapping_path,
    }
    JOBS[job_id] = {
        "status": "queued",
        "message": "Job queued",
        "updated_at": datetime.now().isoformat(timespec="seconds"),
        "progress": 0,
        "phase": "queued",
        "current_epoch": 0,
        "total_epochs": epochs,
        "device_label": "Detecting...",
        "device_backend": "",
        "device_name": "",
    }

    # The child receives JOBS itself so it can update this job's record.
    process = JOB_CONTEXT.Process(
        target=run_finetune_job,
        args=(job_id, input_paths, job_run_dir, config, JOBS),
        daemon=False,
    )
    process.start()
    JOB_PROCESSES[job_id] = process
    # Join the child in the background so it is reaped once it exits.
    Thread(target=_reap_job_process, args=(job_id, process), daemon=True).start()

    # NOTE(review): this read-modify-write is not atomic with respect to the
    # child, which is already running and may update JOBS[job_id] concurrently.
    # An update that lands between the read and the write is overwritten.
    job_record = dict(JOBS[job_id])
    job_record["pid"] = process.pid
    JOBS[job_id] = job_record

    wants_json = (
        "application/json" in request.headers.get("accept", "")
        or request.headers.get("x-requested-with") == "XMLHttpRequest"
    )
    if wants_json:
        return JSONResponse({"job_id": job_id, "status_url": f"/status/{job_id}", "stop_url": f"/stop/{job_id}"})
    return RedirectResponse(url=f"/status/{job_id}", status_code=303)
@app.post("/stop/{job_id}")
def stop_job(job_id: str):
    """Cancel a queued or running fine-tuning job and mark it ``cancelled``.

    If this server holds a live handle to the job's subprocess, the process is
    stopped with ``terminate()``, escalating to ``kill()`` if it is still alive
    after a 5 s join. The job record is then rewritten with status
    ``cancelled`` and progress capped at 99.

    Raises:
        HTTPException: 404 for an unknown job; 409 if the job already finished.
    """
    if job_id not in JOBS:
        raise HTTPException(status_code=404, detail="Job not found")
    job = dict(JOBS[job_id])
    if job.get("status") in {"done", "error", "cancelled"}:
        raise HTTPException(status_code=409, detail="Job is already finished")

    process = JOB_PROCESSES.get(job_id)
    if process is not None and process.is_alive():
        process.terminate()
        process.join(timeout=5)
        if process.is_alive():
            process.kill()
            process.join(timeout=5)
    # No handle means the reaper thread has already joined the child: it removes
    # the entry only after join() returns. The process is therefore gone, and
    # the pid stored in the record may have been reused by an unrelated process.
    # Signalling that pid could kill the wrong process, so nothing is signalled here.

    # Merge over a fresh copy: the child may have updated the record while it
    # was being stopped.
    latest = dict(JOBS.get(job_id, job))
    JOBS[job_id] = {
        **latest,
        "status": "cancelled",
        "message": "Job cancelled by user",
        "phase": "cancelled",
        "progress": min(int(latest.get("progress", 0) or 0), 99),
        "updated_at": datetime.now().isoformat(timespec="seconds"),
    }
    return JSONResponse({"job_id": job_id, "status": "cancelled"})
def _display_class_label(class_names, class_idx: int):
    """Return ``class_names[class_idx]``, or ``str(class_idx)`` when no name exists for that index."""
    return class_names[class_idx] if class_idx < len(class_names) else str(class_idx)


def _build_top5_rows(logits: np.ndarray, class_names) -> list[dict]:
    """Rank each sample's classes by logit and keep at most the top five.

    Args:
        logits: 2-D array with one row per sample and one column per class.
        class_names: Display names indexed by class index.

    Returns:
        One dict per sample with a 1-based ``sample_index`` and a ``top5`` list.
        Each ``top5`` entry has ``rank`` (1-based), ``class_name`` and ``logit``,
        ordered by descending logit.
    """
    k = min(5, logits.shape[1])
    top_indices = np.argsort(logits, axis=1)[:, ::-1][:, :k]
    top_logits = np.take_along_axis(logits, top_indices, axis=1)
    return [
        {
            "sample_index": sample_index,
            "top5": [
                {
                    "rank": rank,
                    "class_name": _display_class_label(class_names, class_idx),
                    "logit": float(logit_value),
                }
                for rank, (class_idx, logit_value) in enumerate(
                    zip(indices_row.tolist(), logits_row.tolist()), start=1
                )
            ],
        }
        for sample_index, (indices_row, logits_row) in enumerate(zip(top_indices, top_logits), start=1)
    ]


def _build_prediction_rows(pred_indices, confidences, class_names) -> list[dict]:
    """Pair each sample's predicted class index with its display label and confidence."""
    rows = []
    for sample_index, (pred_index, confidence) in enumerate(zip(pred_indices, confidences), start=1):
        pred_index = int(pred_index)
        rows.append(
            {
                "sample_index": sample_index,
                "pred_index": pred_index,
                "pred_label": _display_class_label(class_names, pred_index),
                "confidence": float(confidence),
            }
        )
    return rows


def _write_predictions_csv(csv_path: str, rows: list[dict]) -> None:
    """Write prediction rows as CSV.

    The csv module quotes labels that contain commas, quotes or newlines, which
    would otherwise corrupt the columns.
    """
    import csv

    with open(csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(["sample_index", "predicted_index", "predicted_label", "confidence"])
        for row in rows:
            writer.writerow(
                [row["sample_index"], row["pred_index"], row["pred_label"], f"{row['confidence']:.6f}"]
            )


@app.post("/predict")
def run_prediction(
    request: Request,
    spectral_file: UploadFile = File(...),
    wavenumbers_file: Optional[UploadFile] = File(None),
    model_file: UploadFile = File(...),
    label_mapping_file: Optional[UploadFile] = File(None),
    manual_low_cm: Optional[float] = Form(None),
    manual_high_cm: Optional[float] = Form(None),
):
    """Run an uploaded checkpoint on an uploaded spectrum file and report the predictions.

    Wavenumbers are taken from the first available source, in this order:
    1. the spectrum file itself (two-column text);
    2. ``wavenumbers_file``;
    3. the manual ``manual_low_cm``/``manual_high_cm`` range.

    Inputs and outputs are written under ``PREDICTIONS_DIR/<prediction_id>``:
    ``predictions.csv``, ``prediction_summary.json`` and the preview image.

    Returns:
        JSON, including a rendered results fragment, when the client asks for it
        (``Accept: application/json`` or ``X-Requested-With: XMLHttpRequest``).
        Otherwise the rendered ``predict.html`` page.

    Raises:
        HTTPException: 400 in these cases:
        - missing or wrongly typed uploads;
        - unusable spectra or wavenumbers;
        - a label mapping that ``load_label_mapping`` rejects with ValueError;
        - a checkpoint that ``predict_with_checkpoint`` rejects with
          ValueError/RuntimeError.
    """
    if _is_blank_upload(spectral_file):
        raise HTTPException(status_code=400, detail="Please choose a spectral file before running prediction.")
    if _is_blank_upload(model_file):
        raise HTTPException(status_code=400, detail="Please choose a saved model (.pth) before running prediction.")
    if _is_blank_upload(wavenumbers_file):
        wavenumbers_file = None
    if _is_optional_file(label_mapping_file):
        label_mapping_file = None

    spectral_name = _safe_uploaded_name(spectral_file.filename)
    model_name = _safe_uploaded_name(model_file.filename)
    wavenumbers_name = _safe_uploaded_name(wavenumbers_file.filename) if wavenumbers_file is not None else None
    label_mapping_name = _safe_uploaded_name(label_mapping_file.filename) if label_mapping_file is not None else None

    if os.path.splitext(model_name)[1].lower() != ".pth":
        raise HTTPException(status_code=400, detail="Model file must be .pth")
    if os.path.splitext(spectral_name)[1].lower() not in {".npy", ".txt", ".csv"}:
        raise HTTPException(status_code=400, detail="Spectral file must be .npy, .txt, or .csv")
    if wavenumbers_name is not None and os.path.splitext(wavenumbers_name)[1].lower() not in {".npy", ".txt", ".csv"}:
        raise HTTPException(status_code=400, detail="Wavelength file must be .npy, .txt, or .csv")
    if label_mapping_name is not None and os.path.splitext(label_mapping_name)[1].lower() not in {".json", ".txt"}:
        raise HTTPException(status_code=400, detail="True label mapping file must be .json or .txt")

    prediction_id = datetime.now().strftime("%Y%m%d_%H%M%S") + "_" + uuid.uuid4().hex[:8]
    prediction_dir = os.path.join(PREDICTIONS_DIR, prediction_id)
    os.makedirs(prediction_dir, exist_ok=True)

    # Store each input under a role prefix. Uploads may share a name (e.g. a
    # spectrum and a label mapping both called "data.txt"), and bare names would
    # let one upload overwrite another before it is read. The prefix also keeps
    # inputs from clashing with this function's own output file names.
    spectral_path = os.path.join(prediction_dir, f"input_spectrum_{spectral_name}")
    model_path = os.path.join(prediction_dir, f"input_model_{model_name}")
    wavenumbers_path = (
        os.path.join(prediction_dir, f"input_wavenumbers_{wavenumbers_name}") if wavenumbers_name is not None else None
    )
    label_mapping_path = (
        os.path.join(prediction_dir, f"input_label_mapping_{label_mapping_name}")
        if label_mapping_name is not None
        else None
    )
    _save_upload(spectral_file, spectral_path)
    _save_upload(model_file, model_path)
    if wavenumbers_path is not None:
        _save_upload(wavenumbers_file, wavenumbers_path)
    if label_mapping_path is not None:
        _save_upload(label_mapping_file, label_mapping_path)

    try:
        # Loaded inside the try so a malformed mapping (ValueError) is a 400, not a 500.
        display_label_mapping = load_label_mapping(label_mapping_path) if label_mapping_path is not None else None
        spectral, inferred_wavenumbers, spectrum_source = _load_prediction_spectrum(spectral_path)
        # Normalise the rank first: a 0-d array has no last axis to size the
        # manual wavenumber range from.
        spectral = np.asarray(spectral)
        if spectral.ndim == 1:
            spectral = spectral.reshape(1, -1)
        if spectral.ndim != 2:
            raise ValueError(f"Spectrum data must be 1D or 2D after loading, got shape {spectral.shape}")
        if inferred_wavenumbers is not None:
            wavenumbers = inferred_wavenumbers
            wavenumber_source = "embedded_in_spectrum"
        elif wavenumbers_path is not None:
            wavenumbers = _load_prediction_wavenumbers(wavenumbers_path)
            wavenumber_source = "uploaded_wavelength_file"
        elif manual_low_cm is not None or manual_high_cm is not None:
            if manual_low_cm is None or manual_high_cm is None:
                raise ValueError("Manual wavelength range requires both low and high values")
            wavenumbers = _build_manual_wavenumbers(spectral.shape[-1], manual_low_cm, manual_high_cm)
            wavenumber_source = "manual_range"
        else:
            raise HTTPException(
                status_code=400,
                detail="No wavelength information found in the spectrum file. Upload a wavelength file or provide a manual wavelength range.",
            )
        preview_path = _save_prediction_preview(prediction_dir, wavenumbers, spectral)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): the checkpoint is an untrusted upload. Confirm that
    # predict_with_checkpoint loads it with torch.load(weights_only=True):
    # fully unpickling a .pth file can execute arbitrary code.
    try:
        results = predict_with_checkpoint(model_path, spectral, wavenumbers, device, display_label_mapping=display_label_mapping)
    except (ValueError, RuntimeError) as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    # Show human-readable labels when available. An explicitly uploaded mapping
    # takes precedence over one stored in the checkpoint.
    display_class_names = apply_label_mapping(
        results.get("raw_class_names", results.get("class_names", [])),
        display_label_mapping or results.get("checkpoint_label_mapping"),
    )
    top5_rows = _build_top5_rows(np.asarray(results["logits"]), display_class_names)
    rows = _build_prediction_rows(results["pred_indices"], results["confidences"], display_class_names)
    _write_predictions_csv(os.path.join(prediction_dir, "predictions.csv"), rows)

    summary = {
        "prediction_id": prediction_id,
        "num_samples": len(rows),
        "class_names": results["class_names"],
        "raw_class_names": results.get("raw_class_names", []),
        "model_config": results["model_config"],
        "preprocess_config": results["preprocess_config"],
        "download_csv": f"/predictions/{prediction_id}/predictions.csv",
        "preview_image": f"/predictions/{prediction_id}/{os.path.basename(preview_path)}",
        "spectrum_source": spectrum_source,
        "wavenumber_source": wavenumber_source,
        "label_mapping_source": label_mapping_name or ("checkpoint" if results.get("checkpoint_label_mapping") else None),
    }
    with open(os.path.join(prediction_dir, "prediction_summary.json"), "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    wants_json = (
        "application/json" in request.headers.get("accept", "")
        or request.headers.get("x-requested-with") == "XMLHttpRequest"
    )
    if wants_json:
        return JSONResponse(
            {
                "prediction_id": prediction_id,
                "summary": summary,
                "rows": rows,
                "top5_rows": top5_rows,
                "download_csv": summary["download_csv"],
                "preview_image": summary["preview_image"],
                "results_html": _render_predict_results_fragment(
                    prediction_id,
                    summary,
                    rows,
                    top5_rows,
                    summary["download_csv"],
                    summary["preview_image"],
                ),
            }
        )
    return templates.TemplateResponse(
        request,
        "predict.html",
        {
            "request": request,
            "prediction_id": prediction_id,
            "summary": summary,
            "rows": rows,
            "top5_rows": top5_rows,
            "download_csv": summary["download_csv"],
            "preview_image": summary["preview_image"],
        },
    )
@app.get("/status/{job_id}")
def status_page(job_id: str, request: Request):
    """Render ``status.html`` for a fine-tuning job.

    What the page receives:
    - The stored job record, merged over display defaults.
    - The artifacts named in ``summary["artifacts"]`` that exist in
      ``RUNS_DIR/<job_id>``:
      - ``visual_artifacts``: images for training history, t-SNE and confusion matrix;
      - ``download_artifacts``: all known artifact keys.
      Each artifact is listed at most once per group.
    - The classification report text, if present.

    Raises:
        HTTPException: 404 when ``job_id`` is unknown.
    """
    if job_id not in JOBS:
        raise HTTPException(status_code=404, detail="Job not found")
    job = {
        "status": "queued",
        "message": "Job queued",
        "updated_at": None,
        "progress": 0,
        "phase": "queued",
        "current_epoch": 0,
        "total_epochs": 0,
        "device_label": "Detecting...",
        "device_backend": "",
        "device_name": "",
        **JOBS[job_id],
    }
    if not job.get("total_epochs"):
        job["total_epochs"] = 0
    can_stop = job.get("status") in {"queued", "running"}
    summary = job.get("summary", {}) or {}
    artifact_map = summary.get("artifacts", {}) or {}
    run_dir = os.path.join(RUNS_DIR, job_id)
    report_path = os.path.join(run_dir, artifact_map.get("classification_report", "classification_report.txt"))
    visual_keys = ["training_history", "tsne", "confusion_matrix"]
    download_keys = [
        "training_history",
        "tsne",
        "confusion_matrix",
        "roc_curves",
        "classification_report",
        "final_model",
        "best_class_model",
        "best_recon_model",
    ]
    visual_artifacts = []
    download_artifacts = []
    # The visual keys also appear in download_keys. Concatenating the lists
    # would visit them twice and list each of those artifacts twice, so
    # iterate the unique keys in first-seen order instead.
    for key in dict.fromkeys(visual_keys + download_keys):
        filename = artifact_map.get(key)
        if not filename:
            continue
        file_path = os.path.join(run_dir, filename)
        if not os.path.isfile(file_path):
            continue
        artifact_info = {
            "key": key,
            "filename": filename,
            "url": f"/runs/{job_id}/{filename}",
            "is_image": filename.lower().endswith((".png", ".jpg", ".jpeg", ".webp", ".gif")),
        }
        if key in visual_keys and artifact_info["is_image"]:
            visual_artifacts.append(artifact_info)
        if key in download_keys:
            download_artifacts.append(artifact_info)
    return templates.TemplateResponse(
        request,
        "status.html",
        {
            "request": request,
            "job_id": job_id,
            "job": job,
            "summary": summary,
            "can_stop": can_stop,
            "visual_artifacts": visual_artifacts,
            "download_artifacts": download_artifacts,
            "report_text": _load_report_text(report_path),
        },
    )
@app.get("/api/status/{job_id}")
def status_api(job_id: str):
    """Return the raw status record for ``job_id`` as JSON; 404 if the job is unknown."""
    if job_id in JOBS:
        return JOBS[job_id]
    raise HTTPException(status_code=404, detail="Job not found")
@app.get("/runs/{job_id}/{filename}")
def job_artifact(job_id: str, filename: str):
    """Serve one file from a training run's output directory as a download."""
    resolved_path = _safe_result_file(RUNS_DIR, job_id, filename)
    download_name = os.path.basename(resolved_path)
    return FileResponse(resolved_path, filename=download_name)
@app.get("/predictions/{prediction_id}/{filename}")
def prediction_artifact(prediction_id: str, filename: str):
    """Serve one file from a prediction request's directory as a download."""
    resolved_path = _safe_result_file(PREDICTIONS_DIR, prediction_id, filename)
    download_name = os.path.basename(resolved_path)
    return FileResponse(resolved_path, filename=download_name)