# TNM / app.py
# Last commit 2a1e204 by MohamedTry: "Update app.py" (verified)
import logging

import torch
from fastapi import Body, FastAPI
from pydantic import BaseModel
from scipy.special import softmax
from transformers import AutoTokenizer, BigBirdForSequenceClassification

# Import AJCC API logic
from ajcc_api import stage_cancer
# FastAPI application serving TNM prediction and AJCC staging.
app = FastAPI(title="TNM + AJCC Endpoint", version="2.0")

# Hugging Face Hub identifiers of the T, N and M classifiers.
MODEL_T, MODEL_N, MODEL_M = (
    "jkefeli/CancerStage_Classifier_T",
    "jkefeli/CancerStage_Classifier_N",
    "jkefeli/CancerStage_Classifier_M",
)

# Shared tokenizer, loaded once at import time and reused by every request.
tokenizer = AutoTokenizer.from_pretrained("yikuan8/Clinical-BigBird")

# One classifier per TNM component, loaded once at import time in T, N, M order.
model_T, model_N, model_M = (
    BigBirdForSequenceClassification.from_pretrained(model_id)
    for model_id in (MODEL_T, MODEL_N, MODEL_M)
)
class Report(BaseModel):
    """JSON request body accepted by the ``/predict_full`` endpoint.

    Fields are validated by pydantic when the request is parsed.
    """

    # Free-text report to classify.
    text: str
    # Cancer site forwarded to ``stage_cancer`` (e.g. "breast", "lung");
    # the set of supported values lives in ajcc_api, not visible here.
    cancer_type: str = "colon"
def predict_class(text, model):
    """Return the index of the most probable class that ``model`` assigns to ``text``.

    The text is encoded with the shared module-level ``tokenizer``
    (truncated to at most 2048 tokens, returned as PyTorch tensors) and
    passed through ``model`` with gradient tracking disabled.

    Args:
        text: Report text to classify.
        model: Sequence-classification model whose output exposes ``logits``.

    Returns:
        The predicted class index as a plain Python ``int``.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=2048,
    )
    with torch.no_grad():
        logits = model(**encoded).logits
    # Softmax over the class axis, then keep the only row (one input text).
    row_probs = softmax(logits.numpy(), axis=1)[0]
    return int(row_probs.argmax())
@app.get("/")
def health_check():
    """Report that the service is up and which model checkpoints it serves."""
    served = {"T": MODEL_T, "N": MODEL_N, "M": MODEL_M}
    return {"status": "running", "models": served}
@app.post("/predict_full")
def predict_full(report: Report = Body(...)):
    """Predict T, N and M labels for a report and derive its AJCC stage.

    Args:
        report: Request body carrying the report ``text`` and ``cancer_type``.

    Returns:
        On success, a dict with the echoed ``input``, the ``TNM_prediction``
        labels, the ``AJCC_stage`` returned by ``stage_cancer`` and the
        normalized ``cancer_type_used``. On any exception raised while
        predicting or staging, ``{"error": <message>}`` (returned as a normal
        response, as before); the full traceback is logged server-side.
    """
    text = report.text
    # Strip surrounding whitespace as well as case so that values such as
    # " Colon " reach stage_cancer as "colon".
    cancer = report.cancer_type.strip().lower()
    try:
        # 1) Predict a numeric class index with each TNM-component model.
        t_class = predict_class(text, model_T)
        n_class = predict_class(text, model_N)
        m_class = predict_class(text, model_M)

        # 2) Convert class indices to TNM labels.
        # NOTE(review): this assumes class index i means category i
        # (index 0 -> "T0"). Confirm against each model's label mapping;
        # if the T classifier's classes start at T1, this is off by one.
        T = f"T{t_class}"
        N = f"N{n_class}"
        M = f"M{m_class}"

        # 3) Compute the AJCC stage from the predicted labels.
        stage = stage_cancer(cancer_type=cancer, T=T, N=N, M=M)
    except Exception as e:
        # Top-level request boundary: keep the existing error payload for
        # clients, but record the traceback instead of discarding it.
        logging.getLogger(__name__).exception("predict_full failed")
        return {"error": str(e)}

    return {
        "input": text,
        "TNM_prediction": {"T": T, "N": N, "M": M},
        "AJCC_stage": stage,
        "cancer_type_used": cancer,
    }