Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, Body | |
| from pydantic import BaseModel | |
| from transformers import AutoTokenizer, BigBirdForSequenceClassification | |
| from scipy.special import softmax | |
| import torch | |
| # Import AJCC API logic | |
| from ajcc_api import stage_cancer | |
# FastAPI application object for this service.
app = FastAPI(title="TNM + AJCC Endpoint", version="2.0")

# Hugging Face Hub checkpoints for the three TNM classifiers (T, N and M axes).
MODEL_T = "jkefeli/CancerStage_Classifier_T"
MODEL_N = "jkefeli/CancerStage_Classifier_N"
MODEL_M = "jkefeli/CancerStage_Classifier_M"

# Tokenizer and classifiers are created once at import time; the request
# handlers below reuse these module-level objects instead of reloading them.
tokenizer = AutoTokenizer.from_pretrained("yikuan8/Clinical-BigBird")
model_T, model_N, model_M = (
    BigBirdForSequenceClassification.from_pretrained(checkpoint)
    for checkpoint in (MODEL_T, MODEL_N, MODEL_M)
)
class Report(BaseModel):
    """Request body accepted by the TNM + AJCC prediction handler."""

    # Free-text clinical report that the T, N and M classifiers read.
    text: str
    # Cancer site passed on to the AJCC staging helper; defaults to "colon".
    # Other values (e.g. "breast", "lung") work only if stage_cancer supports them.
    cancer_type: str = "colon"
def predict_class(text, model, *, max_length=2048, tok=None):
    """Return the index of the highest-scoring class the model assigns to *text*.

    Args:
        text: Free-text report to classify.
        model: Sequence-classification model whose output exposes ``.logits``.
        max_length: Token limit handed to the tokenizer; longer input is
            truncated. Defaults to 2048, the value this module always used.
        tok: Optional tokenizer override; defaults to the module-level
            ``tokenizer``.

    Returns:
        The predicted class index as a plain ``int`` (first item of the batch).
    """
    tok = tokenizer if tok is None else tok
    inputs = tok(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length,
    )
    # Put the input tensors on the same device as the model's weights. The
    # original assumed CPU and would fail if a model were moved to a GPU.
    device = getattr(model, "device", None)
    if device is not None:
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    # Softmax is strictly increasing, so argmax over the raw logits selects the
    # same class as argmax over the probabilities. This also avoids the
    # .numpy() round-trip, which does not work for CUDA tensors.
    return int(torch.argmax(logits, dim=-1)[0].item())
def health_check():
    """Report that the service is running and list the checkpoints it serves.

    NOTE(review): no route decorator (e.g. ``@app.get("/")``) is visible on
    this function; confirm it is registered with ``app`` somewhere.
    """
    checkpoints = dict(T=MODEL_T, N=MODEL_N, M=MODEL_M)
    return dict(status="running", models=checkpoints)
def predict_full(report: Report = Body(...)):
    """Predict T, N and M labels for a report and derive its AJCC stage.

    NOTE(review): no ``@app.post(...)`` decorator is visible on this function,
    so FastAPI will not expose it unless it is registered elsewhere; confirm.

    Args:
        report: Request body with the report ``text`` and ``cancer_type``.

    Returns:
        On success, a dict with the echoed input text, the TNM labels, the
        value returned by ``stage_cancer`` and the normalised cancer type.
        On any failure, ``{"error": <message>}``. This is returned as a normal
        response body, not raised as an HTTP error.
    """
    text = report.text
    # Normalise so that inputs like " Colon" and "colon" are treated the same.
    cancer = report.cancer_type.strip().lower()
    try:
        # 1) Predict the numeric class index for each TNM axis.
        t_class = predict_class(text, model_T)
        n_class = predict_class(text, model_N)
        m_class = predict_class(text, model_M)
        # 2) Turn class indices into labels, e.g. 2 -> "T2".
        # NOTE(review): assumes class index i means category i (index 0 is
        # "T0"); TODO confirm against the classifiers' label maps.
        T = f"T{t_class}"
        N = f"N{n_class}"
        M = f"M{m_class}"
        # 3) Derive the AJCC stage from the TNM labels.
        stage = stage_cancer(cancer_type=cancer, T=T, N=N, M=M)
    except Exception as e:
        # Top-level boundary: keep the {"error": ...} contract clients rely on,
        # but log the traceback instead of silently discarding it.
        import logging

        logging.getLogger(__name__).exception("TNM/AJCC prediction failed")
        return {"error": str(e)}
    return {
        "input": text,
        "TNM_prediction": {"T": T, "N": N, "M": M},
        "AJCC_stage": stage,
        "cancer_type_used": cancer,
    }