"""FastAPI service that predicts TNM classes for a report and derives an AJCC stage.

Three sequence-classification models (one each for T, N and M) are loaded once
at import time together with a shared tokenizer. The ``/predict_full`` endpoint
runs all three models on the submitted text and hands the resulting labels to
``ajcc_api.stage_cancer``.
"""

import logging

from fastapi import FastAPI, Body
from pydantic import BaseModel
from transformers import AutoTokenizer, BigBirdForSequenceClassification
from scipy.special import softmax
import torch

# Import AJCC API logic
from ajcc_api import stage_cancer

logger = logging.getLogger(__name__)

# Initialize FastAPI
app = FastAPI(title="TNM + AJCC Endpoint", version="2.0")

# Models (TNM) from Hugging Face
MODEL_T = "jkefeli/CancerStage_Classifier_T"
MODEL_N = "jkefeli/CancerStage_Classifier_N"
MODEL_M = "jkefeli/CancerStage_Classifier_M"

# Load tokenizer once (shared by all three classifiers)
tokenizer = AutoTokenizer.from_pretrained("yikuan8/Clinical-BigBird")

# Load models once, at import time, so requests do not pay the loading cost
model_T = BigBirdForSequenceClassification.from_pretrained(MODEL_T)
model_N = BigBirdForSequenceClassification.from_pretrained(MODEL_N)
model_M = BigBirdForSequenceClassification.from_pretrained(MODEL_M)


class Report(BaseModel):
    """Request body for ``/predict_full``."""

    # Free-text report to classify.
    text: str
    # default, you can change it (breast, lung, etc.)
    cancer_type: str = "colon"


def predict_class(text: str, model) -> int:
    """Return the index of the highest-probability class for *text*.

    The text is tokenized with the module-level tokenizer, truncated to at
    most 2048 tokens, and passed through *model* with gradients disabled.

    Args:
        text: Report text to classify.
        model: One of the loaded sequence-classification models; it must
            accept the tokenizer output as keyword arguments and expose
            ``.logits`` on its result.

    Returns:
        The predicted class index as a plain ``int``.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=2048,
    )
    with torch.no_grad():
        outputs = model(**inputs)
    probs = softmax(outputs.logits.numpy(), axis=1)
    # A single text yields a batch of one, so take row 0 of the argmax.
    pred_class = probs.argmax(axis=1)[0]
    return int(pred_class)


@app.get("/")
def health_check():
    """Report that the service is running and which model identifiers it uses."""
    return {"status": "running", "models": {"T": MODEL_T, "N": MODEL_N, "M": MODEL_M}}


@app.post("/predict_full")
def predict_full(report: Report = Body(...)):
    """Predict T, N and M labels for a report and compute its AJCC stage.

    Args:
        report: Request body carrying the report text and the cancer type
            (lower-cased before being passed to ``stage_cancer``).

    Returns:
        On success, a dict with the keys ``input``, ``TNM_prediction``,
        ``AJCC_stage`` and ``cancer_type_used``. If anything raises, a dict
        ``{"error": <message>}`` is returned instead; the exception is logged
        with its traceback so the failure is not lost.
    """
    text = report.text
    cancer = report.cancer_type.lower()

    try:
        # 1) Predict numeric TNM classes
        t_class = predict_class(text, model_T)
        n_class = predict_class(text, model_N)
        m_class = predict_class(text, model_M)

        # 2) Convert numeric classes → TNM labels
        # NOTE(review): assumes each model's class index equals the TNM
        # category number — confirm against the model cards.
        T = f"T{t_class}"
        N = f"N{n_class}"
        M = f"M{m_class}"

        # 3) Compute AJCC Stage
        stage = stage_cancer(cancer_type=cancer, T=T, N=N, M=M)

        return {
            "input": text,
            "TNM_prediction": {"T": T, "N": N, "M": M},
            "AJCC_stage": stage,
            "cancer_type_used": cancer,
        }

    except Exception as e:
        # Keep the existing error payload for clients, but record the
        # traceback server-side instead of silently discarding it.
        logger.exception("predict_full failed for cancer_type=%r", cancer)
        return {"error": str(e)}