| import logging |
| |
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
import multiprocessing

# Make "fork" the process-wide default start method when the platform has it.
# set_start_method() raises ValueError when the requested method does not exist
# on this platform and RuntimeError when the context cannot be changed; neither
# should stop the module from importing, so both are reported and ignored.
try:
    multiprocessing.set_start_method('fork', force=True)
except (RuntimeError, ValueError) as exc:
    logging.getLogger(__name__).debug("Keeping default multiprocessing start method: %s", exc)
| import os |
| import multiprocessing |
| import signal |
| import shutil |
| import uuid |
| import re |
| from datetime import datetime |
| from pathlib import Path |
| from threading import Thread |
| from typing import Optional |
|
|
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import json |
| import torch |
| from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile |
| from fastapi.responses import FileResponse, JSONResponse, RedirectResponse |
| from fastapi.templating import Jinja2Templates |
|
|
| from webserver.train_service import TrainConfig, predict_with_checkpoint, run_finetune_job |
| from webserver.label_utils import load_label_mapping, apply_label_mapping |
|
|
# Filesystem layout: every working directory is a sibling of this module.
BASE_DIR = os.path.dirname(__file__)
UPLOAD_DIR, RUNS_DIR, PREDICTIONS_DIR, TEMPLATE_DIR = (
    os.path.join(BASE_DIR, folder)
    for folder in ("uploads", "runs", "predictions", "templates")
)
DEFAULT_MODEL_PATH = os.path.join(BASE_DIR, "weights", "Fine_tuned.pth")

# The three writable directories are created at import time if missing.
for _writable_dir in (UPLOAD_DIR, RUNS_DIR, PREDICTIONS_DIR):
    os.makedirs(_writable_dir, exist_ok=True)
del _writable_dir
|
|
# ASGI application object; the route handlers in this module register on it.
app = FastAPI(
    title="Raman Fine-Tune Webserver",
)
# Jinja2 template loader rooted at TEMPLATE_DIR.
templates = Jinja2Templates(directory=TEMPLATE_DIR)
|
|
| |
def _init_job_store():
    """Return ``(manager, jobs)`` for the job-status registry.

    In the main process a ``multiprocessing.Manager`` and a dict proxy created
    from it are returned. In any other process, or when the manager cannot be
    created, the result is ``(None, {})`` — a plain local dict.
    """
    try:
        if multiprocessing.current_process().name != "MainProcess":
            return None, {}
        manager = multiprocessing.Manager()
        return manager, manager.dict()
    except Exception as e:
        print(f"Warning: multiprocessing.Manager() failed: {e}. Using local dict instead.")
        return None, {}


def _resolve_job_context():
    """Return the "spawn" multiprocessing context, or the module as a fallback."""
    try:
        return multiprocessing.get_context("spawn")
    except Exception as e:
        print(f"Warning: spawn context not available: {e}. Using default context.")
        return multiprocessing


# Job status records keyed by job id.
JOB_MANAGER, JOBS = _init_job_store()

# Process handles of jobs started by this process, keyed by job id.
JOB_PROCESSES = {}

# Context whose ``Process`` class is used to launch jobs.
JOB_CONTEXT = _resolve_job_context()
|
|
|
|
def _save_upload(file_obj: UploadFile, dst_path: str):
    """Copy the upload's underlying file object to ``dst_path``.

    The destination is opened in binary write mode, so an existing file at
    that path is replaced.
    """
    with open(dst_path, mode="wb") as destination:
        source_stream = file_obj.file
        shutil.copyfileobj(source_stream, destination)
|
|
|
|
| def _load_report_text(report_path: str): |
| if not os.path.isfile(report_path): |
| return None |
| with open(report_path, "r", encoding="utf-8") as f: |
| return f.read() |
|
|
|
|
| def _build_artifact_entries(base_dir: str, artifact_map: dict, route_prefix: str): |
| entries = [] |
| for key, filename in artifact_map.items(): |
| file_path = os.path.join(base_dir, filename) |
| if not os.path.isfile(file_path): |
| continue |
| entries.append( |
| { |
| "key": key, |
| "filename": filename, |
| "url": f"/{route_prefix}/{os.path.basename(base_dir)}/{filename}", |
| "is_image": filename.lower().endswith((".png", ".jpg", ".jpeg", ".webp", ".gif")), |
| "is_text": filename.lower().endswith((".txt", ".json", ".csv")), |
| } |
| ) |
| return entries |
|
|
|
|
| def _safe_result_file(root_dir: str, item_id: str, filename: str): |
| safe_name = os.path.basename(filename) |
| folder = os.path.join(root_dir, item_id) |
| file_path = os.path.join(folder, safe_name) |
| if not os.path.isfile(file_path): |
| raise HTTPException(status_code=404, detail="File not found") |
| return file_path |
|
|
|
|
| def _safe_uploaded_name(filename: str) -> str: |
| safe_name = os.path.basename(filename or "") |
| if not safe_name: |
| raise HTTPException(status_code=400, detail="Uploaded file is missing a filename") |
| return safe_name |
|
|
|
|
def _is_optional_file(upload: Optional[UploadFile]) -> bool:
    """Return True when an optional upload field was effectively left empty.

    That is: no upload object, a missing or falsy ``filename``, or a filename
    consisting only of whitespace.
    """
    if upload is None:
        return True
    name = getattr(upload, "filename", "")
    if not name:
        return True
    return not str(name).strip()
|
|
|
|
def _is_blank_upload(upload: Optional[UploadFile]) -> bool:
    """Return True when ``upload`` carries no usable filename.

    Covers ``None``, a missing or falsy ``filename`` attribute, and a filename
    that is blank after stripping whitespace.
    """
    filename = getattr(upload, "filename", "") if upload is not None else ""
    return not (filename and str(filename).strip())
|
|
|
|
def _render_predict_results_fragment(
    prediction_id: str,
    summary: dict,
    rows: list[dict],
    top5_rows: list[dict],
    download_csv: str,
    preview_image: str,
):
    """Render ``predict_result_fragment.html`` to a string.

    Every argument is passed to the template under its own name.
    """
    context = {
        "prediction_id": prediction_id,
        "summary": summary,
        "rows": rows,
        "top5_rows": top5_rows,
        "download_csv": download_csv,
        "preview_image": preview_image,
    }
    fragment = templates.env.get_template("predict_result_fragment.html")
    return fragment.render(**context)
|
|
|
|
| def _parse_numeric_text_file(file_path: str) -> np.ndarray: |
| rows = [] |
| with open(file_path, "r", encoding="utf-8", errors="ignore") as f: |
| for raw_line in f: |
| line = raw_line.strip() |
| if not line or line.startswith("#"): |
| continue |
| tokens = [token for token in re.split(r"[\s,]+", line) if token] |
| values = [] |
| for token in tokens: |
| try: |
| values.append(float(token)) |
| except ValueError: |
| continue |
| if values: |
| rows.append(values) |
|
|
| if not rows: |
| raise ValueError("No numeric data found in text file") |
|
|
| max_cols = max(len(row) for row in rows) |
| if max_cols == 1: |
| return np.asarray([row[0] for row in rows], dtype=np.float32) |
|
|
| return np.asarray([[row[0], row[1]] for row in rows if len(row) >= 2], dtype=np.float32) |
|
|
|
|
| def _load_prediction_spectrum(file_path: str) -> tuple[np.ndarray, Optional[np.ndarray], str]: |
| extension = os.path.splitext(file_path)[1].lower() |
| if extension in {".txt", ".csv"}: |
| data = _parse_numeric_text_file(file_path) |
| if data.ndim == 1: |
| spectra = data.astype(np.float32).reshape(1, -1) |
| return spectra, None, "text_intensity_only" |
|
|
| if data.ndim == 2 and data.shape[1] >= 2: |
| wavenumbers = data[:, 0].astype(np.float32) |
| spectra = data[:, 1].astype(np.float32).reshape(1, -1) |
| return spectra, wavenumbers, "text_wavenumber_intensity" |
|
|
| raise ValueError("Text spectrum must contain either one intensity column or two columns: wavenumber, intensity") |
|
|
| if extension == ".npy": |
| spectra = np.load(file_path, allow_pickle=True) |
| return np.asarray(spectra, dtype=np.float32), None, "npy" |
|
|
| raise ValueError("Spectrum file must be .txt, .csv, or .npy") |
|
|
|
|
| def _load_prediction_wavenumbers(file_path: str) -> np.ndarray: |
| extension = os.path.splitext(file_path)[1].lower() |
| if extension in {".txt", ".csv"}: |
| data = _parse_numeric_text_file(file_path) |
| if data.ndim == 1: |
| return data.astype(np.float32).reshape(-1) |
| if data.ndim == 2 and data.shape[1] >= 1: |
| return data[:, 0].astype(np.float32).reshape(-1) |
| raise ValueError("Wavelength text file must contain one numeric column") |
|
|
| if extension == ".npy": |
| return np.asarray(np.load(file_path, allow_pickle=True), dtype=np.float32).reshape(-1) |
|
|
| raise ValueError("Wavelength file must be .txt, .csv, or .npy") |
|
|
|
|
| def _build_manual_wavenumbers(length: int, low_cm: float, high_cm: float) -> np.ndarray: |
| if low_cm is None or high_cm is None: |
| raise ValueError("Manual wavelength range requires both low and high values") |
| if high_cm <= low_cm: |
| raise ValueError("Manual wavelength range high value must be greater than low value") |
| return np.linspace(float(low_cm), float(high_cm), int(length), dtype=np.float32) |
|
|
|
|
| def _save_prediction_preview(prediction_dir: str, target_wavenumbers: np.ndarray, processed_spectra: np.ndarray) -> str: |
| spectra = np.asarray(processed_spectra, dtype=np.float32) |
| wavenumbers = np.asarray(target_wavenumbers, dtype=np.float32).reshape(-1) |
| if spectra.ndim != 2 or spectra.shape[1] != wavenumbers.shape[0]: |
| raise ValueError("processed spectra and wavenumbers must have matching 2D/1D shapes") |
|
|
| sample_count = spectra.shape[0] |
| preview_count = min(sample_count, 6) |
| fig, ax = plt.subplots(figsize=(8, 4.5)) |
| for idx in range(preview_count): |
| label = f"Sample {idx + 1}" if sample_count > 1 else "Input spectrum" |
| ax.plot(wavenumbers, spectra[idx], linewidth=1.0, alpha=0.9, label=label) |
|
|
| ax.set_title(f"Input Spectra Preview ({sample_count} sample{'s' if sample_count != 1 else ''})") |
| ax.set_xlabel("Wavenumber (cm$^{-1}$)") |
| ax.set_ylabel("Normalized intensity") |
| ax.set_xlim(float(wavenumbers.min()), float(wavenumbers.max())) |
| ax.grid(True, linestyle="--", alpha=0.3) |
| if preview_count > 1: |
| ax.legend(frameon=False, fontsize=8) |
| fig.tight_layout() |
|
|
| preview_path = os.path.join(prediction_dir, "input_spectra_preview.png") |
| fig.savefig(preview_path, dpi=300, bbox_inches="tight") |
| plt.close(fig) |
| return preview_path |
|
|
|
|
def _reap_job_process(job_id: str, process: multiprocessing.Process):
    """Block until ``process`` exits, then drop its handle from JOB_PROCESSES.

    Used as the target of a background thread started for each job.
    """
    # join() has no timeout here, so this only returns once the process ended.
    process.join()
    # A default is supplied so that an already-removed entry is not an error.
    JOB_PROCESSES.pop(job_id, None)
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Print interpreter, PyTorch/CUDA and job-infrastructure details at startup."""
    import sys

    print("[STARTUP] Application initializing...")
    print(f"[STARTUP] Python version: {sys.version}")
    print(f"[STARTUP] PyTorch: {torch.__version__}")

    cuda_ready = torch.cuda.is_available()
    print(f"[STARTUP] CUDA available: {cuda_ready}")
    if cuda_ready:
        # Only the first CUDA device is reported.
        print(f"[STARTUP] CUDA device: {torch.cuda.get_device_name(0)}")

    manager_kind = 'multiprocessing.Manager' if JOB_MANAGER else 'local dict'
    print(f"[STARTUP] JOB_MANAGER: {manager_kind}")
    print(f"[STARTUP] JOB_CONTEXT: {type(JOB_CONTEXT).__name__}")
    print("[STARTUP] Application ready!")
|
|
|
|
@app.get("/health")
def health_check():
    """Liveness probe: reports a fixed "ok" status and whether CUDA is usable."""
    cuda_available = torch.cuda.is_available()
    return {"status": "ok", "cuda": cuda_available}
|
|
|
|
@app.get("/")
def index(request: Request):
    """Render the landing page template ``index.html``."""
    context = {"request": request}
    return templates.TemplateResponse(request, "index.html", context)
|
|
|
|
@app.get("/predict")
def predict_page(request: Request):
    """Render the empty prediction form (``predict.html``)."""
    context = {"request": request}
    return templates.TemplateResponse(request, "predict.html", context)
|
|
|
|
@app.post("/start")
def start_job(
    request: Request,
    spectral_file: UploadFile = File(...),
    labels_file: UploadFile = File(...),
    wavenumbers_file: UploadFile = File(...),
    model_file: UploadFile = File(None),
    label_mapping_file: Optional[UploadFile] = File(None),
    epochs: int = Form(60),
    lr: float = Form(1e-4),
    weight_decay: float = Form(1e-3),
    patience: int = Form(12),
    batch_size: int = Form(64),
    patch_num: int = Form(100),
    embedding_dim: int = Form(512),
    num_layers: int = Form(12),
    num_heads: int = Form(16),
    freeze_encoder: bool = Form(False),
    label_smoothing: float = Form(0.0),
):
    """Validate and store the uploads, then launch a fine-tuning job.

    The three data files must be ``.npy``; an optional checkpoint must be
    ``.pth`` and an optional label mapping ``.json`` or ``.txt``. Extensions
    are compared case-insensitively, matching the ``/predict`` handler. The
    job runs ``run_finetune_job`` in a separate process and its status record
    lives in ``JOBS``.

    Returns:
        JSON with the job id and status/stop URLs for JSON/XHR clients,
        otherwise a 303 redirect to the job's status page.

    Raises:
        HTTPException: 400 for a wrong file type, 500 when the worker process
            cannot be started.
    """

    def _extension(upload):
        # A missing filename counts as "no extension" instead of crashing.
        return os.path.splitext(upload.filename or "")[1].lower()

    if _is_optional_file(label_mapping_file):
        label_mapping_file = None
    has_model = not _is_blank_upload(model_file)

    for required in (spectral_file, labels_file, wavenumbers_file):
        if _extension(required) != ".npy":
            raise HTTPException(status_code=400, detail=f"File {required.filename} must be .npy")
    if has_model and _extension(model_file) != ".pth":
        raise HTTPException(status_code=400, detail="Model file must be .pth")
    if label_mapping_file is not None and _extension(label_mapping_file) not in {".json", ".txt"}:
        raise HTTPException(status_code=400, detail="Label mapping file must be .json or .txt")

    job_id = datetime.now().strftime("%Y%m%d_%H%M%S") + "_" + uuid.uuid4().hex[:8]
    job_upload_dir = os.path.join(UPLOAD_DIR, job_id)
    job_run_dir = os.path.join(RUNS_DIR, job_id)
    os.makedirs(job_upload_dir, exist_ok=True)
    os.makedirs(job_run_dir, exist_ok=True)

    # Data files and the checkpoint are stored under fixed names; only the
    # label mapping keeps its (sanitised) uploaded name.
    spectral_path = os.path.join(job_upload_dir, "spectral.npy")
    labels_path = os.path.join(job_upload_dir, "labels.npy")
    wavenumbers_path = os.path.join(job_upload_dir, "wavenumbers.npy")
    model_path = os.path.join(job_upload_dir, "model.pth") if has_model else None
    label_mapping_path = (
        os.path.join(job_upload_dir, _safe_uploaded_name(label_mapping_file.filename))
        if label_mapping_file is not None
        else None
    )

    _save_upload(spectral_file, spectral_path)
    _save_upload(labels_file, labels_path)
    _save_upload(wavenumbers_file, wavenumbers_path)
    if model_path is not None:
        _save_upload(model_file, model_path)
    if label_mapping_file is not None:
        _save_upload(label_mapping_file, label_mapping_path)

    config = TrainConfig(
        epochs=epochs,
        lr=lr,
        weight_decay=weight_decay,
        patience=patience,
        batch_size=batch_size,
        patch_num=patch_num,
        embedding_dim=embedding_dim,
        num_layers=num_layers,
        num_heads=num_heads,
        freeze_encoder=freeze_encoder,
        label_smoothing=label_smoothing,
    )

    input_paths = {
        "spectral": spectral_path,
        "labels": labels_path,
        "wavenumbers": wavenumbers_path,
        "model": model_path,
        "label_mapping": label_mapping_path,
    }

    JOBS[job_id] = {
        "status": "queued",
        "message": "Job queued",
        "updated_at": datetime.now().isoformat(timespec="seconds"),
        "progress": 0,
        "phase": "queued",
        "current_epoch": 0,
        "total_epochs": epochs,
        "device_label": "Detecting...",
        "device_backend": "",
        "device_name": "",
    }

    process = JOB_CONTEXT.Process(
        target=run_finetune_job,
        args=(job_id, input_paths, job_run_dir, config, JOBS),
        daemon=False,
    )
    try:
        process.start()
    except Exception as exc:
        # Without this the record would stay "queued" forever although no
        # worker exists to advance it.
        JOBS[job_id] = {
            **JOBS[job_id],
            "status": "error",
            "phase": "error",
            "message": f"Failed to start training process: {exc}",
            "updated_at": datetime.now().isoformat(timespec="seconds"),
        }
        raise HTTPException(status_code=500, detail="Failed to start training process") from exc

    JOB_PROCESSES[job_id] = process
    Thread(target=_reap_job_process, args=(job_id, process), daemon=True).start()

    # Record the worker pid in the job record; stop_job reads it when it has
    # no process handle for the job.
    job_record = dict(JOBS[job_id])
    job_record["pid"] = process.pid
    JOBS[job_id] = job_record

    wants_json = (
        "application/json" in request.headers.get("accept", "")
        or request.headers.get("x-requested-with") == "XMLHttpRequest"
    )
    if wants_json:
        return JSONResponse({"job_id": job_id, "status_url": f"/status/{job_id}", "stop_url": f"/stop/{job_id}"})

    return RedirectResponse(url=f"/status/{job_id}", status_code=303)
|
|
|
|
@app.post("/stop/{job_id}")
def stop_job(job_id: str):
    """Cancel a queued or running job and mark it as cancelled.

    If this process holds the job's process handle, the process is terminated
    and, when still alive after five seconds, killed. Otherwise SIGTERM is
    sent to the pid stored in the job record.

    Raises:
        HTTPException: 404 for an unknown job, 409 when the job has already
            reached a final state.
    """
    if job_id not in JOBS:
        raise HTTPException(status_code=404, detail="Job not found")

    job = dict(JOBS[job_id])
    if job.get("status") in {"done", "error", "cancelled"}:
        raise HTTPException(status_code=409, detail="Job is already finished")

    process = JOB_PROCESSES.get(job_id)
    if process is not None:
        if process.is_alive():
            process.terminate()
            process.join(timeout=5)
            if process.is_alive():
                # terminate() was not honoured within the grace period.
                process.kill()
                process.join(timeout=5)
    else:
        pid = job.get("pid")
        if pid:
            try:
                os.kill(int(pid), signal.SIGTERM)
            except (ProcessLookupError, PermissionError):
                # Either the process is already gone, or the pid now belongs
                # to a process this server may not signal; in both cases there
                # is nothing left to stop and the job is still marked
                # cancelled below.
                pass

    JOBS[job_id] = {
        **job,
        "status": "cancelled",
        "message": "Job cancelled by user",
        "phase": "cancelled",
        # Capped at 99 so a cancelled job never shows as fully complete.
        "progress": min(int(job.get("progress", 0) or 0), 99),
        "updated_at": datetime.now().isoformat(timespec="seconds"),
    }
    return JSONResponse({"job_id": job_id, "status": "cancelled"})
|
|
|
|
def _unique_stored_name(name: str, role: str, taken: set) -> str:
    """Return ``name``, prefixed with ``role`` if it would clash with ``taken``.

    Names are compared case-insensitively and the chosen name is added to
    ``taken``.
    """
    candidate = name
    while candidate.lower() in taken:
        candidate = f"{role}_{candidate}"
    taken.add(candidate.lower())
    return candidate


def _collect_top5_rows(logits, class_names) -> list:
    """Build, per sample, the (up to) five highest-logit classes in rank order."""
    limit = min(5, logits.shape[1])
    top_indices = np.argsort(logits, axis=1)[:, ::-1][:, :limit]
    top_logits = np.take_along_axis(logits, top_indices, axis=1)
    collected = []
    for sample_index, (indices_row, logits_row) in enumerate(zip(top_indices, top_logits), start=1):
        ranked = [
            {
                "rank": rank,
                "class_name": class_names[class_idx] if class_idx < len(class_names) else str(class_idx),
                "logit": float(logit_value),
            }
            for rank, (class_idx, logit_value) in enumerate(
                zip(indices_row.tolist(), logits_row.tolist()), start=1
            )
        ]
        collected.append({"sample_index": sample_index, "top5": ranked})
    return collected


def _write_predictions_csv(csv_path: str, rows: list) -> None:
    """Write the per-sample predictions as CSV, quoting labels where needed."""
    import csv

    # newline="" plus an explicit platform terminator keeps the line endings a
    # text-mode file would produce, while csv handles quoting of labels that
    # contain commas, quotes or newlines.
    with open(csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f, lineterminator=os.linesep)
        writer.writerow(["sample_index", "predicted_index", "predicted_label", "confidence"])
        for row in rows:
            writer.writerow(
                [row["sample_index"], row["pred_index"], row["pred_label"], f"{row['confidence']:.6f}"]
            )


@app.post("/predict")
def run_prediction(
    request: Request,
    spectral_file: UploadFile = File(...),
    wavenumbers_file: Optional[UploadFile] = File(None),
    model_file: UploadFile = File(...),
    label_mapping_file: Optional[UploadFile] = File(None),
    manual_low_cm: Optional[float] = Form(None),
    manual_high_cm: Optional[float] = Form(None),
):
    """Run a saved checkpoint on uploaded spectra and present the results.

    The wavenumber axis is taken, in order of preference, from a two-column
    text spectrum, from the uploaded wavelength file, or from the manual
    low/high range. Uploads, a preview image, ``predictions.csv`` and
    ``prediction_summary.json`` are written to a fresh directory under
    ``PREDICTIONS_DIR``.

    Returns:
        A JSON payload (including a rendered HTML fragment) for JSON/XHR
        clients, otherwise the rendered ``predict.html`` page.

    Raises:
        HTTPException: 400 for missing or invalid inputs and for errors
            raised while loading data or running the model.
    """
    if _is_blank_upload(spectral_file):
        raise HTTPException(status_code=400, detail="Please choose a spectral file before running prediction.")
    if _is_blank_upload(model_file):
        raise HTTPException(status_code=400, detail="Please choose a saved model (.pth) before running prediction.")
    if _is_blank_upload(wavenumbers_file):
        wavenumbers_file = None
    if _is_optional_file(label_mapping_file):
        label_mapping_file = None

    spectral_name = _safe_uploaded_name(spectral_file.filename)
    model_name = _safe_uploaded_name(model_file.filename)
    wavenumbers_name = _safe_uploaded_name(wavenumbers_file.filename) if wavenumbers_file is not None else None
    label_mapping_name = _safe_uploaded_name(label_mapping_file.filename) if label_mapping_file is not None else None

    if os.path.splitext(model_name)[1].lower() != ".pth":
        raise HTTPException(status_code=400, detail="Model file must be .pth")
    if os.path.splitext(spectral_name)[1].lower() not in {".npy", ".txt", ".csv"}:
        raise HTTPException(status_code=400, detail="Spectral file must be .npy, .txt, or .csv")
    if wavenumbers_file is not None:
        if os.path.splitext(wavenumbers_name or "")[1].lower() not in {".npy", ".txt", ".csv"}:
            raise HTTPException(status_code=400, detail="Wavelength file must be .npy, .txt, or .csv")
    if label_mapping_file is not None:
        if os.path.splitext(label_mapping_name or "")[1].lower() not in {".json", ".txt"}:
            raise HTTPException(status_code=400, detail="True label mapping file must be .json or .txt")

    prediction_id = datetime.now().strftime("%Y%m%d_%H%M%S") + "_" + uuid.uuid4().hex[:8]
    prediction_dir = os.path.join(PREDICTIONS_DIR, prediction_id)
    os.makedirs(prediction_dir, exist_ok=True)

    # All uploads share one directory. Two uploads with the same client
    # filename (e.g. spectrum and label mapping both "data.txt") would
    # overwrite each other, so clashing names get a role prefix. The names of
    # the generated outputs are reserved as well.
    taken = {"predictions.csv", "prediction_summary.json", "input_spectra_preview.png"}
    spectral_path = os.path.join(prediction_dir, _unique_stored_name(spectral_name, "spectral", taken))
    model_path = os.path.join(prediction_dir, _unique_stored_name(model_name, "model", taken))
    wavenumbers_path = (
        os.path.join(prediction_dir, _unique_stored_name(wavenumbers_name, "wavenumbers", taken))
        if wavenumbers_name is not None
        else None
    )
    label_mapping_path = (
        os.path.join(prediction_dir, _unique_stored_name(label_mapping_name, "label_mapping", taken))
        if label_mapping_name is not None
        else None
    )

    _save_upload(spectral_file, spectral_path)
    _save_upload(model_file, model_path)
    if wavenumbers_file is not None:
        _save_upload(wavenumbers_file, wavenumbers_path)
    if label_mapping_file is not None:
        _save_upload(label_mapping_file, label_mapping_path)

    try:
        # Loaded inside the try so that a ValueError from a malformed mapping
        # (json.JSONDecodeError is a ValueError subclass) becomes a 400.
        display_label_mapping = None
        if label_mapping_path is not None:
            display_label_mapping = load_label_mapping(label_mapping_path)

        spectral, inferred_wavenumbers, spectrum_source = _load_prediction_spectrum(spectral_path)

        if inferred_wavenumbers is not None:
            wavenumbers = inferred_wavenumbers
            wavenumber_source = "embedded_in_spectrum"
        elif wavenumbers_file is not None:
            wavenumbers = _load_prediction_wavenumbers(wavenumbers_path)
            wavenumber_source = "uploaded_wavelength_file"
        elif manual_low_cm is not None or manual_high_cm is not None:
            if manual_low_cm is None or manual_high_cm is None:
                raise ValueError("Manual wavelength range requires both low and high values")
            wavenumbers = _build_manual_wavenumbers(spectral.shape[-1], manual_low_cm, manual_high_cm)
            wavenumber_source = "manual_range"
        else:
            raise HTTPException(
                status_code=400,
                detail="No wavelength information found in the spectrum file. Upload a wavelength file or provide a manual wavelength range.",
            )

        if spectral.ndim == 1:
            spectral = spectral.reshape(1, -1)
        if spectral.ndim != 2:
            raise ValueError(f"Spectrum data must be 1D or 2D after loading, got shape {spectral.shape}")

        preview_path = _save_prediction_preview(prediction_dir, wavenumbers, spectral)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        results = predict_with_checkpoint(model_path, spectral, wavenumbers, device, display_label_mapping=display_label_mapping)
    except (ValueError, RuntimeError) as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    # An uploaded mapping takes precedence over the one reported with the results.
    display_class_names = apply_label_mapping(
        results.get("raw_class_names", results.get("class_names", [])),
        display_label_mapping or results.get("checkpoint_label_mapping"),
    )

    top5_rows = _collect_top5_rows(results["logits"], display_class_names)

    rows = []
    for idx, (pred_index, confidence) in enumerate(
        zip(results["pred_indices"], results["confidences"]),
        start=1,
    ):
        pred_index = int(pred_index)
        pred_label_display = display_class_names[pred_index] if pred_index < len(display_class_names) else str(pred_index)
        rows.append(
            {
                "sample_index": idx,
                "pred_index": pred_index,
                "pred_label": pred_label_display,
                "confidence": float(confidence),
            }
        )

    _write_predictions_csv(os.path.join(prediction_dir, "predictions.csv"), rows)

    summary = {
        "prediction_id": prediction_id,
        "num_samples": len(rows),
        "class_names": results["class_names"],
        "raw_class_names": results.get("raw_class_names", []),
        "model_config": results["model_config"],
        "preprocess_config": results["preprocess_config"],
        "download_csv": f"/predictions/{prediction_id}/predictions.csv",
        "preview_image": f"/predictions/{prediction_id}/{os.path.basename(preview_path)}",
        "spectrum_source": spectrum_source,
        "wavenumber_source": wavenumber_source,
        "label_mapping_source": label_mapping_name or ("checkpoint" if results.get("checkpoint_label_mapping") else None),
    }
    with open(os.path.join(prediction_dir, "prediction_summary.json"), "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    wants_json = (
        "application/json" in request.headers.get("accept", "")
        or request.headers.get("x-requested-with") == "XMLHttpRequest"
    )
    if wants_json:
        return JSONResponse(
            {
                "prediction_id": prediction_id,
                "summary": summary,
                "rows": rows,
                "top5_rows": top5_rows,
                "download_csv": summary["download_csv"],
                "preview_image": summary["preview_image"],
                "results_html": _render_predict_results_fragment(
                    prediction_id,
                    summary,
                    rows,
                    top5_rows,
                    summary["download_csv"],
                    summary["preview_image"],
                ),
            }
        )

    return templates.TemplateResponse(
        request,
        "predict.html",
        {
            "request": request,
            "prediction_id": prediction_id,
            "summary": summary,
            "rows": rows,
            "top5_rows": top5_rows,
            "download_csv": summary["download_csv"],
            "preview_image": summary["preview_image"],
        },
    )
|
|
|
|
@app.get("/status/{job_id}")
def status_page(job_id: str, request: Request):
    """Render the status page of a job, including any artifacts already on disk.

    The stored job record is laid over a set of defaults so the template
    always finds every field. Artifact filenames come from the ``artifacts``
    entry of the job's summary; only files that exist in the job's run
    directory are listed.

    Raises:
        HTTPException: 404 when the job id is unknown.
    """
    if job_id not in JOBS:
        raise HTTPException(status_code=404, detail="Job not found")
    job = {
        "status": "queued",
        "message": "Job queued",
        "updated_at": None,
        "progress": 0,
        "phase": "queued",
        "current_epoch": 0,
        "total_epochs": 0,
        "device_label": "Detecting...",
        "device_backend": "",
        "device_name": "",
        **JOBS[job_id],
    }
    if not job.get("total_epochs"):
        job["total_epochs"] = 0
    can_stop = job.get("status") in {"queued", "running"}
    summary = job.get("summary", {}) or {}
    artifact_map = summary.get("artifacts", {}) or {}
    run_dir = os.path.join(RUNS_DIR, job_id)
    report_path = os.path.join(run_dir, artifact_map.get("classification_report", "classification_report.txt"))

    visual_keys = ["training_history", "tsne", "confusion_matrix"]
    download_keys = [
        "training_history",
        "tsne",
        "confusion_matrix",
        "roc_curves",
        "classification_report",
        "final_model",
        "best_class_model",
        "best_recon_model",
    ]

    visual_artifacts = []
    download_artifacts = []
    # Keys present in both lists must be visited once only; iterating the
    # plain concatenation listed those artifacts twice. dict.fromkeys removes
    # the repeats while keeping first-seen order.
    for key in dict.fromkeys(visual_keys + download_keys):
        filename = artifact_map.get(key)
        if not filename:
            continue
        file_path = os.path.join(run_dir, filename)
        if not os.path.isfile(file_path):
            continue
        artifact_info = {
            "key": key,
            "filename": filename,
            "url": f"/runs/{job_id}/{filename}",
            "is_image": filename.lower().endswith((".png", ".jpg", ".jpeg", ".webp", ".gif")),
        }
        if key in visual_keys and artifact_info["is_image"]:
            visual_artifacts.append(artifact_info)
        if key in download_keys:
            download_artifacts.append(artifact_info)

    return templates.TemplateResponse(
        request,
        "status.html",
        {
            "request": request,
            "job_id": job_id,
            "job": job,
            "summary": summary,
            "can_stop": can_stop,
            "visual_artifacts": visual_artifacts,
            "download_artifacts": download_artifacts,
            "report_text": _load_report_text(report_path),
        },
    )
|
|
|
|
@app.get("/api/status/{job_id}")
def status_api(job_id: str):
    """Return the stored status record of a job; 404 if the id is unknown."""
    if job_id in JOBS:
        return JOBS[job_id]
    raise HTTPException(status_code=404, detail="Job not found")
|
|
|
|
@app.get("/runs/{job_id}/{filename}")
def job_artifact(job_id: str, filename: str):
    """Serve one file from a job's run directory.

    The path is resolved by ``_safe_result_file``, which raises a 404 when the
    file does not exist.
    """
    resolved = _safe_result_file(RUNS_DIR, job_id, filename)
    download_name = os.path.basename(resolved)
    return FileResponse(resolved, filename=download_name)
|
|
|
|
@app.get("/predictions/{prediction_id}/{filename}")
def prediction_artifact(prediction_id: str, filename: str):
    """Serve one file from a prediction's output directory.

    The path is resolved by ``_safe_result_file``, which raises a 404 when the
    file does not exist.
    """
    resolved = _safe_result_file(PREDICTIONS_DIR, prediction_id, filename)
    download_name = os.path.basename(resolved)
    return FileResponse(resolved, filename=download_name)
|
|