File size: 2,174 Bytes
3957394
c22cacc
e93b18a
05a2172
 
3957394
2a1e204
 
 
05a2172
2a1e204
05a2172
e93b18a
05a2172
 
 
 
 
e93b18a
05a2172
2a1e204
e93b18a
 
 
c22cacc
3957394
c22cacc
2a1e204
c22cacc
2a1e204
e93b18a
05a2172
 
e93b18a
05a2172
2a1e204
05a2172
3957394
 
e93b18a
c22cacc
2a1e204
 
3957394
2a1e204
 
05a2172
2a1e204
 
 
 
 
 
 
 
 
 
 
 
e93b18a
05a2172
 
2a1e204
 
 
3957394
2a1e204
05a2172
e93b18a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import logging

import torch
from fastapi import FastAPI, Body
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from scipy.special import softmax
from transformers import AutoTokenizer, BigBirdForSequenceClassification

# Import AJCC API logic
from ajcc_api import stage_cancer

# FastAPI application object that the route decorators below attach to
app = FastAPI(title="TNM + AJCC Endpoint", version="2.0")

# Hugging Face repository ids of the three TNM classifiers
MODEL_T = "jkefeli/CancerStage_Classifier_T"
MODEL_N = "jkefeli/CancerStage_Classifier_N"
MODEL_M = "jkefeli/CancerStage_Classifier_M"

# One tokenizer is shared by all three classifiers; it is loaded at import
# time so that requests never pay the loading cost.
tokenizer = AutoTokenizer.from_pretrained("yikuan8/Clinical-BigBird")

# Classifiers are likewise loaded once at import time, in T, N, M order.
model_T, model_N, model_M = (
    BigBirdForSequenceClassification.from_pretrained(repo_id)
    for repo_id in (MODEL_T, MODEL_N, MODEL_M)
)

# Request body accepted by the /predict_full endpoint.
# NOTE: documented with comments rather than a class docstring so that the
# schema pydantic generates for this model stays exactly as it was.
class Report(BaseModel):
    # Free-text report; passed unchanged to each of the T, N and M classifiers.
    text: str
    # Cancer type forwarded (lower-cased) to stage_cancer; "colon" when omitted.
    cancer_type: str = "colon"   # default, you can change it (breast, lung, etc.)

def predict_class(text, model):
    """Return the index of the most probable class that ``model`` assigns to ``text``.

    Args:
        text: Report text; tokenized with the module-level ``tokenizer`` and
            truncated to at most 2048 tokens.
        model: Sequence-classification model whose output exposes ``logits``.

    Returns:
        The winning class index for the first (only) row of the batch, as a
        plain ``int``.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=2048,
    )
    # Inference only: no gradients are needed for a forward pass.
    with torch.no_grad():
        logits = model(**encoded).logits
        class_probs = softmax(logits.numpy(), axis=1)
        best_index = class_probs.argmax(axis=1)[0]
    return int(best_index)

@app.get("/")
def health_check():
    """Report that the service is running and which model repositories it uses."""
    model_ids = dict(T=MODEL_T, N=MODEL_N, M=MODEL_M)
    return dict(status="running", models=model_ids)

@app.post("/predict_full")
def predict_full(report: Report = Body(...)):
    """Predict the T, N and M classes of a report and derive its AJCC stage.

    Args:
        report: Request body; ``text`` is passed to each of the three
            classifiers and ``cancer_type`` is lower-cased before being
            handed to ``stage_cancer``.

    Returns:
        On success, a dict holding the echoed input text, the predicted
        ``T``/``N``/``M`` labels, the ``stage_cancer`` result and the cancer
        type that was used. If prediction or staging raises, a
        ``JSONResponse`` with HTTP status 500 and body
        ``{"error": <message>}``, so failures are no longer reported as
        HTTP 200.
    """
    text = report.text
    cancer = report.cancer_type.lower()

    try:
        # 1) Predict numeric TNM classes
        t_class = predict_class(text, model_T)
        n_class = predict_class(text, model_N)
        m_class = predict_class(text, model_M)

        # 2) Convert numeric classes → TNM labels
        # NOTE(review): the class index is used directly as the stage digit;
        # confirm against each classifier's label mapping that index 0 really
        # denotes T0/N0/M0 (and not, e.g., T1) before relying on the stage.
        T = f"T{t_class}"
        N = f"N{n_class}"
        M = f"M{m_class}"

        # 3) Compute AJCC Stage
        stage = stage_cancer(cancer_type=cancer, T=T, N=N, M=M)
    except Exception as e:
        # Endpoint boundary: keep the original {"error": ...} body, but record
        # the traceback and signal the failure through the status code. The
        # report text is deliberately left out of the log message.
        logging.getLogger(__name__).exception(
            "TNM/AJCC prediction failed (cancer_type=%s)", cancer
        )
        return JSONResponse(status_code=500, content={"error": str(e)})

    return {
        "input": text,
        "TNM_prediction": {"T": T, "N": N, "M": M},
        "AJCC_stage": stage,
        "cancer_type_used": cancer,
    }