Spaces:
Runtime error
Runtime error
File size: 2,174 Bytes
from fastapi import FastAPI, Body
from pydantic import BaseModel
from transformers import AutoTokenizer, BigBirdForSequenceClassification
from scipy.special import softmax
import torch
# Import AJCC API logic
from ajcc_api import stage_cancer
# FastAPI application object; the route decorators further down register
# their handlers on it.
app = FastAPI(title="TNM + AJCC Endpoint", version="2.0")
# Hugging Face model identifiers, one classifier per TNM component
# (T, N and M). They are also echoed back by the health-check route.
MODEL_T = "jkefeli/CancerStage_Classifier_T"
MODEL_N = "jkefeli/CancerStage_Classifier_N"
MODEL_M = "jkefeli/CancerStage_Classifier_M"
# A single tokenizer is loaded once at import time and shared by all three
# classifiers, so individual requests do not pay the loading cost.
tokenizer = AutoTokenizer.from_pretrained("yikuan8/Clinical-BigBird")
# The three classifiers are likewise loaded once at import time.
# NOTE(review): presumably this fetches weights on first start-up, which
# would make the service unavailable until it finishes - confirm.
model_T = BigBirdForSequenceClassification.from_pretrained(MODEL_T)
model_N = BigBirdForSequenceClassification.from_pretrained(MODEL_N)
model_M = BigBirdForSequenceClassification.from_pretrained(MODEL_M)
class Report(BaseModel):
    # Request payload accepted by the prediction endpoint.

    # Free-text report that the TNM classifiers are run on.
    text: str

    # Cancer type forwarded to the AJCC staging step. Callers may override
    # the "colon" default (breast, lung, etc.).
    cancer_type: str = "colon"
def predict_class(text, model):
    """Return the index of the class that *model* scores highest for *text*.

    Args:
        text: Report text to classify. It is tokenized with the module-level
            ``tokenizer`` and truncated to at most 2048 tokens.
        model: Sequence-classification model whose forward output exposes
            ``logits`` and which reports its placement through ``device``.

    Returns:
        The winning class index as a plain ``int``.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=2048,
    )
    # Keep the inputs on the same device as the model so the forward pass
    # also works when the model has been moved off the CPU.
    inputs = inputs.to(model.device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax is strictly increasing, so the arg-max of the raw logits is the
    # arg-max of the probabilities. Taking it on the tensor avoids the NumPy
    # round trip, which only works for CPU tensors.
    return int(torch.argmax(logits, dim=1)[0])
@app.get("/")
def health_check():
    # Liveness probe: reports that the service is up and which model
    # identifiers it serves for each TNM component.
    model_ids = {"T": MODEL_T, "N": MODEL_N, "M": MODEL_M}
    return {"status": "running", "models": model_ids}
@app.post("/predict_full")
def predict_full(report: Report = Body(...)):
    """Predict the T, N and M classes of a report and derive its AJCC stage.

    Args:
        report: Request body carrying the report ``text`` and the
            ``cancer_type`` used for staging.

    Returns:
        On success, a dict with the input text, the predicted TNM labels, the
        AJCC stage and the normalized cancer type. On failure, a dict with a
        single ``"error"`` key holding the failure message.
    """
    import logging

    text = report.text
    # Normalize the cancer type so stray whitespace or capitalization in the
    # request does not change what is passed to the staging step.
    cancer = report.cancer_type.strip().lower()

    # Refuse to classify an empty report: there is nothing to stage, and a
    # prediction made from no text would be meaningless.
    if not text.strip():
        return {"error": "text must not be empty"}

    try:
        # 1) Predict numeric TNM classes
        t_class = predict_class(text, model_T)
        n_class = predict_class(text, model_N)
        m_class = predict_class(text, model_M)

        # 2) Convert numeric classes -> TNM labels
        # NOTE(review): this assumes class index k means label suffix k
        # (e.g. class 0 -> "T0"). Confirm against each classifier's label
        # map; an off-by-one here would shift every reported stage.
        T = f"T{t_class}"
        N = f"N{n_class}"
        M = f"M{m_class}"

        # 3) Compute AJCC Stage
        stage = stage_cancer(cancer_type=cancer, T=T, N=N, M=M)

        return {
            "input": text,
            "TNM_prediction": {"T": T, "N": N, "M": M},
            "AJCC_stage": stage,
            "cancer_type_used": cancer,
        }
    except Exception as e:
        # Keep the existing error payload for clients, but record the full
        # traceback so failures can be diagnosed from the server logs.
        logging.getLogger(__name__).exception("predict_full failed")
        return {"error": str(e)}