MohamedTry commited on
Commit
2a1e204
·
verified ·
1 Parent(s): 933c078

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -14
app.py CHANGED
@@ -4,8 +4,11 @@ from transformers import AutoTokenizer, BigBirdForSequenceClassification
4
  from scipy.special import softmax
5
  import torch
6
 
 
 
 
7
  # Initialize FastAPI
8
- app = FastAPI(title="TNM Endpoint", version="1.0")
9
 
10
  # Models (TNM) from Hugging Face
11
  MODEL_T = "jkefeli/CancerStage_Classifier_T"
@@ -15,41 +18,52 @@ MODEL_M = "jkefeli/CancerStage_Classifier_M"
15
  # Load tokenizer once
16
  tokenizer = AutoTokenizer.from_pretrained("yikuan8/Clinical-BigBird")
17
 
18
- # Load models once (CPU mode)
19
  model_T = BigBirdForSequenceClassification.from_pretrained(MODEL_T)
20
  model_N = BigBirdForSequenceClassification.from_pretrained(MODEL_N)
21
  model_M = BigBirdForSequenceClassification.from_pretrained(MODEL_M)
22
 
23
  class Report(BaseModel):
24
  text: str
 
25
 
26
- def predict_stage(text, model):
27
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=2048)
28
  with torch.no_grad():
29
  outputs = model(**inputs)
30
  probs = softmax(outputs.logits.numpy(), axis=1)
31
  pred_class = probs.argmax(axis=1)[0]
32
- return {"class": int(pred_class), "probs": probs.tolist()}
33
 
34
  @app.get("/")
35
  def health_check():
36
  return {"status": "running", "models": {"T": MODEL_T, "N": MODEL_N, "M": MODEL_M}}
37
 
38
- @app.post("/predict_tnm")
39
- def predict_tnm(report: Report = Body(...)):
40
  text = report.text
 
 
41
  try:
42
- t_result = predict_stage(text, model_T)
43
- n_result = predict_stage(text, model_N)
44
- m_result = predict_stage(text, model_M)
 
 
 
 
 
 
 
 
 
45
 
46
  return {
47
  "input": text,
48
- "TNM_prediction": {
49
- "T": t_result,
50
- "N": n_result,
51
- "M": m_result
52
- }
53
  }
 
54
  except Exception as e:
55
  return {"error": str(e)}
 
4
  from scipy.special import softmax
5
  import torch
6
 
7
+ # Import AJCC API logic
8
+ from ajcc_api import stage_cancer
9
+
10
  # Initialize FastAPI
11
+ app = FastAPI(title="TNM + AJCC Endpoint", version="2.0")
12
 
13
  # Models (TNM) from Hugging Face
14
  MODEL_T = "jkefeli/CancerStage_Classifier_T"
 
18
  # Load tokenizer once
19
  tokenizer = AutoTokenizer.from_pretrained("yikuan8/Clinical-BigBird")
20
 
21
+ # Load models once
22
  model_T = BigBirdForSequenceClassification.from_pretrained(MODEL_T)
23
  model_N = BigBirdForSequenceClassification.from_pretrained(MODEL_N)
24
  model_M = BigBirdForSequenceClassification.from_pretrained(MODEL_M)
25
 
26
  class Report(BaseModel):
27
  text: str
28
+ cancer_type: str = "colon" # default, you can change it (breast, lung, etc.)
29
 
30
+ def predict_class(text, model):
31
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=2048)
32
  with torch.no_grad():
33
  outputs = model(**inputs)
34
  probs = softmax(outputs.logits.numpy(), axis=1)
35
  pred_class = probs.argmax(axis=1)[0]
36
+ return int(pred_class)
37
 
38
  @app.get("/")
39
  def health_check():
40
  return {"status": "running", "models": {"T": MODEL_T, "N": MODEL_N, "M": MODEL_M}}
41
 
42
+ @app.post("/predict_full")
43
+ def predict_full(report: Report = Body(...)):
44
  text = report.text
45
+ cancer = report.cancer_type.lower()
46
+
47
  try:
48
+ # 1) Predict numeric TNM classes
49
+ t_class = predict_class(text, model_T)
50
+ n_class = predict_class(text, model_N)
51
+ m_class = predict_class(text, model_M)
52
+
53
+ # 2) Convert numeric classes → TNM labels
54
+ T = f"T{t_class}"
55
+ N = f"N{n_class}"
56
+ M = f"M{m_class}"
57
+
58
+ # 3) Compute AJCC Stage
59
+ stage = stage_cancer(cancer_type=cancer, T=T, N=N, M=M)
60
 
61
  return {
62
  "input": text,
63
+ "TNM_prediction": {"T": T, "N": N, "M": M},
64
+ "AJCC_stage": stage,
65
+ "cancer_type_used": cancer
 
 
66
  }
67
+
68
  except Exception as e:
69
  return {"error": str(e)}