File size: 4,613 Bytes
3142c97
 
 
 
 
237e309
3142c97
a30a7a5
3142c97
237e309
3142c97
 
 
237e309
3142c97
 
 
 
a30a7a5
3142c97
 
 
 
237e309
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3142c97
 
 
237e309
3142c97
 
 
 
 
a30a7a5
 
 
 
 
 
 
 
3142c97
 
 
 
237e309
3142c97
237e309
a30a7a5
237e309
3142c97
 
237e309
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a30a7a5
3142c97
 
 
a30a7a5
3142c97
 
a30a7a5
3142c97
 
 
 
237e309
 
3142c97
a30a7a5
3142c97
 
 
237e309
 
 
 
 
 
 
3142c97
237e309
 
 
 
3142c97
237e309
 
 
3142c97
 
 
 
 
 
 
 
237e309
 
3142c97
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn as nn
import spacy
import spacy.cli
import time
import os

# ASGI application object; the title and description are handed to FastAPI
# as API metadata.
app = FastAPI(
    description="SciBERT + BERTsum Fine-Tuned Engine for Medical Reports",
    title="Clinical Extractive Summarization",
)

# CORS is configured wide open: every origin, method and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# --- ARCHITECTURE DEFINITION ---
class BioExtractor(nn.Module):
    """Sentence-salience scorer: a pretrained encoder plus a linear head.

    The encoder is loaded with ``AutoModel.from_pretrained``. The hidden
    state at sequence position 0 is passed through a single-output linear
    layer followed by a sigmoid, giving one score per input sequence.

    Args:
        model_name: Identifier or path forwarded to
            ``AutoModel.from_pretrained``.
    """

    def __init__(self, model_name):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        # Size the head from the encoder's own config rather than assuming
        # 768; fall back to 768 when no ``hidden_size`` is exposed.
        encoder_config = getattr(self.bert, "config", None)
        hidden_size = getattr(encoder_config, "hidden_size", 768)
        # The classification layer that predicts sentence salience
        self.classifier = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Score each sequence in the batch.

        Args:
            input_ids: Token ids passed to the encoder as ``input_ids``.
            attention_mask: Mask passed to the encoder as ``attention_mask``.

        Returns:
            Tensor of shape ``(batch, 1)`` holding sigmoid-activated scores.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 of the last hidden state is used as the sequence summary.
        cls_output = outputs.last_hidden_state[:, 0, :]
        return self.sigmoid(self.classifier(cls_output))

# Module-level cache for the tokenizer, scoring model and spaCy pipeline.
# All three start as None and are filled in by the summarize endpoint the
# first time it runs.
tokenizer = model = nlp = None
# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class ReportRequest(BaseModel):
    # Request body accepted by POST /api/summarize.

    # Raw report text; it is split into sentences before scoring.
    text: str
    # How many sentences the summary should contain; defaults to 3.
    num_sentences: int = 3

@app.get("/")
def health_check():
    # Root endpoint: returns a fixed payload that reports the service is up
    # and points callers at the summarization route and the Swagger UI.
    return dict(
        status="Engine is running",
        message="Send POST requests to /api/summarize",
        docs="Visit /docs for the Swagger UI",
    )

def _load_summarization_models():
    """Load the tokenizer, scoring model and spaCy pipeline into the cache.

    Everything is built in local variables and only assigned to the
    module-level globals after every step has succeeded. ``model`` is
    assigned last because callers use ``model is None`` as the "not loaded
    yet" check: when ``model`` is set, ``tokenizer`` and ``nlp`` are set too.
    A failure part-way through leaves the cache empty, so the next request
    retries instead of running with a partially initialised state.

    Two requests that arrive before loading has finished may both run this
    function; the work is duplicated but the published state is complete.
    """
    global tokenizer, model, nlp

    print("Initializing Fine-Tuned SciBERT and SpaCy...")

    # Load the base tokenizer
    model_name = "allenai/scibert_scivocab_uncased"
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Instantiate the custom architecture
    loaded_model = BioExtractor(model_name)

    # Load the trained weights from the uploaded .pt file
    model_path = "med_summarizer_trained.pt"
    if os.path.exists(model_path):
        print(f"Loading fine-tuned weights from {model_path}...")
        # map_location remaps the stored tensors onto the selected device
        loaded_model.load_state_dict(torch.load(model_path, map_location=device))
    else:
        # Best effort: keep serving with whatever weights the freshly
        # constructed model has, but make the problem visible in the logs.
        print(f"WARNING: {model_path} not found! Upload it to your Space.")

    loaded_model.to(device)
    loaded_model.eval()  # Inference mode (disables dropout etc.)

    try:
        loaded_nlp = spacy.load("en_core_web_sm")
    except OSError:
        # spaCy raises OSError when the pipeline package is not installed.
        print("Downloading SpaCy English model...")
        spacy.cli.download("en_core_web_sm")
        loaded_nlp = spacy.load("en_core_web_sm")

    print("Models loaded successfully!")

    # Publish only now that everything loaded; the sentinel goes last.
    tokenizer = loaded_tokenizer
    nlp = loaded_nlp
    model = loaded_model


@app.post("/api/summarize")
def summarize_medical_report(request: ReportRequest):
    """Build an extractive summary from the highest-scoring sentences.

    The report is split into sentences, each sentence is scored by the
    model, and the ``num_sentences`` best ones are returned in their
    original order. Models are loaded lazily on the first call.

    Returns:
        A dict with ``summary`` and ``metadata``. When the report has no
        more usable sentences than requested, the original text is returned
        with ``metadata.status == "too_short"``.

    Raises:
        HTTPException: 422 when ``num_sentences`` is smaller than 1.
    """
    # perf_counter is monotonic, so the elapsed time cannot go negative.
    start_time = time.perf_counter()

    # A non-positive count would turn the top-N slice below into "drop from
    # the end" (negative) or "nothing" (zero), so reject it up front.
    if request.num_sentences < 1:
        raise HTTPException(
            status_code=422,
            detail="num_sentences must be at least 1",
        )

    if model is None:
        _load_summarization_models()

    # 1. Split text into sentences using SpaCy, dropping fragments of
    #    10 characters or fewer.
    doc = nlp(request.text)
    stripped = (sent.text.strip() for sent in doc.sents)
    sentences = [text for text in stripped if len(text) > 10]

    # Edge case: Report is too short to summarize
    if len(sentences) <= request.num_sentences:
        return {"summary": request.text, "metadata": {"status": "too_short"}}

    # 2. Score the sentences with the fine-tuned model, in small batches
    #    rather than one forward pass per sentence. Every row is padded to
    #    the same fixed length as before.
    batch_size = 16
    scores = []
    with torch.no_grad():
        for offset in range(0, len(sentences), batch_size):
            batch = sentences[offset:offset + batch_size]
            inputs = tokenizer(
                batch,
                return_tensors="pt",
                truncation=True,
                padding='max_length',
                max_length=128,
            ).to(device)
            output = model(inputs['input_ids'], inputs['attention_mask'])
            # Model output is (batch, 1); flatten to one float per sentence.
            scores.extend(output.reshape(-1).tolist())

    # 3. Rank sentence indices by score (stable sort: ties keep report
    #    order) and keep the top N.
    ranked_indices = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
    top_indices = ranked_indices[:request.num_sentences]

    # 4. Restore chronological order to maintain the original report flow
    top_indices_sorted = sorted(top_indices)
    final_summary = " ".join(sentences[i] for i in top_indices_sorted)

    process_time = round((time.perf_counter() - start_time) * 1000, 2)

    return {
        "summary": final_summary,
        "metadata": {
            "processing_time_ms": process_time,
            "original_length": len(sentences),
            "summary_length": len(top_indices_sorted),
            "engine": "SciBERT + BERTsum Fine-Tuned"
        }
    }