from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn as nn
import spacy
import spacy.cli
import time
import os
# FastAPI application instance; routes are registered on it below.
app = FastAPI(
    title="Clinical Extractive Summarization",
    description="SciBERT + BERTsum Fine-Tuned Engine for Medical Reports"
)
# Wide-open CORS so any front-end (e.g. a demo UI on another domain) can call
# the API. NOTE(review): allow_origins=["*"] is permissive — restrict for
# production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- ARCHITECTURE DEFINITION ---
class BioExtractor(nn.Module):
    """BERTsum-style extractive scorer: a SciBERT encoder plus a linear head
    that maps the sentence's [CLS] embedding to a salience probability.

    Args:
        model_name: Hugging Face model id or local path passed to
            ``AutoModel.from_pretrained``.
    """

    def __init__(self, model_name):
        super(BioExtractor, self).__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        # The classification layer that predicts sentence salience.
        # Sized from the encoder config rather than a hard-coded 768 so the
        # class also works with encoders of other hidden sizes; for SciBERT
        # (hidden_size=768) this is identical to the previous behavior.
        self.classifier = nn.Linear(self.bert.config.hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Return a (batch, 1) tensor of salience probabilities in [0, 1]."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # The [CLS] token embedding summarizes the whole sentence.
        cls_output = outputs.last_hidden_state[:, 0, :]
        return self.sigmoid(self.classifier(cls_output))
# Global variables to cache models in memory
# (populated lazily on the first /api/summarize request).
tokenizer = None  # transformers AutoTokenizer for SciBERT
model = None      # BioExtractor with fine-tuned weights loaded
nlp = None        # spaCy pipeline used for sentence segmentation
# Prefer GPU when available; all tensors and the model are moved here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class ReportRequest(BaseModel):
    """Request payload for POST /api/summarize."""
    # Full text of the medical report to summarize.
    text: str
    # Number of top-ranked sentences to keep in the summary.
    num_sentences: int = 3
@app.get("/")
def health_check():
    """Liveness probe: confirms the engine is up and points at the real API."""
    payload = {
        "status": "Engine is running",
        "message": "Send POST requests to /api/summarize",
        "docs": "Visit /docs for the Swagger UI",
    }
    return payload
def _load_models():
    """Lazily initialize the tokenizer, fine-tuned extractor, and spaCy pipeline.

    Populates the module-level caches (tokenizer, model, nlp) so subsequent
    requests skip the expensive construction. Safe to call on every request.
    """
    global tokenizer, model, nlp
    if model is not None:
        return
    print("Initializing Fine-Tuned SciBERT and SpaCy...")
    # Load the base tokenizer
    model_name = "allenai/scibert_scivocab_uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Instantiate the custom extractive architecture
    model = BioExtractor(model_name)
    # Load the trained weights from the uploaded .pt file
    model_path = "med_summarizer_trained.pt"
    if os.path.exists(model_path):
        print(f"Loading fine-tuned weights from {model_path}...")
        # map_location ensures it works even if Hugging Face runs on a CPU space;
        # weights_only=True restricts unpickling to tensors (safer for a
        # state_dict than arbitrary-object pickle loading).
        model.load_state_dict(
            torch.load(model_path, map_location=device, weights_only=True)
        )
    else:
        print(f"WARNING: {model_path} not found! Upload it to your Space.")
    model.to(device)
    model.eval()  # Lock the model for inference
    try:
        nlp = spacy.load("en_core_web_sm")
    except OSError:
        print("Downloading SpaCy English model...")
        spacy.cli.download("en_core_web_sm")
        nlp = spacy.load("en_core_web_sm")
    print("Models loaded successfully!")


@app.post("/api/summarize")
def summarize_medical_report(request: ReportRequest):
    """Extractive summarization endpoint.

    Scores every sentence of the report with the fine-tuned model and returns
    the top ``request.num_sentences`` sentences in original document order.
    """
    start_time = time.time()
    _load_models()
    # 1. Safely split text into sentences using SpaCy NLP
    doc = nlp(request.text)
    # Filter out extremely short fragments (headings, stray tokens)
    sentences = [sent.text.strip() for sent in doc.sents if len(sent.text.strip()) > 10]
    # Edge case: report is too short to summarize — return it unchanged
    if len(sentences) <= request.num_sentences:
        return {"summary": request.text, "metadata": {"status": "too_short"}}
    # 2. Score all sentences in one batched forward pass instead of one model
    # call per sentence. Dynamic padding (padding=True) pads only to the
    # longest sentence in the batch; padded tokens are masked out via the
    # attention mask, so scores match the per-sentence version.
    with torch.no_grad():
        inputs = tokenizer(
            sentences,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128,
        ).to(device)
        # (N, 1) -> flat list of N floats
        scores = model(inputs['input_ids'], inputs['attention_mask']).squeeze(-1).tolist()
    # 3. Rank by salience and select the top N sentence indices
    scored_sentences = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)
    top_indices = [idx for idx, score in scored_sentences[:request.num_sentences]]
    # 4. Sort indices chronologically to maintain original report flow
    top_indices_sorted = sorted(top_indices)
    final_summary = " ".join(sentences[i] for i in top_indices_sorted)
    process_time = round((time.time() - start_time) * 1000, 2)
    return {
        "summary": final_summary,
        "metadata": {
            "processing_time_ms": process_time,
            "original_length": len(sentences),
            "summary_length": len(top_indices_sorted),
            "engine": "SciBERT + BERTsum Fine-Tuned",
        },
    }