"""Pointer-generator decoding on top of T5 and rule-based answer shaping for medical QA."""

import re

import torch
import torch.nn as nn

try:
    from transformers import T5ForConditionalGeneration
except ImportError:
    # Only PointerGeneratorT5 needs transformers; MedicalQAProcessor's text
    # post-processing stays importable (and testable) without it.
    T5ForConditionalGeneration = None


# Whole-word negation markers.  Substring tests such as ``'no' in text`` also
# fire on "diagnosis", "another" or "normal", which flipped answers to "No".
_NEGATION_RE = re.compile(
    r"\b(?:no|not|none|never|cannot|neither|nor|impossible)\b|n't\b",
    re.IGNORECASE,
)
_ALL_QUANTIFIER_RE = re.compile(r"\b(?:all|everyone)\b")
_RA_WORD_RE = re.compile(r"\bra\b")

# Entity texts containing one of these words are never used as the subject.
_EXCLUDE_TERMS = frozenset({
    'age', 'time', 'date', 'frequency', 'often', 'monitored', 'diagnosed',
    'treated', 'caused', 'prevented', 'managed', 'controlled', 'positive',
    'negative', 'men', 'women', 'patients', 'people', 'individuals',
    'initially', 'stable', 'reduce', 'increase', 'decrease', 'checked',
    'happens', 'begin', 'cured', 'annual', 'risk', 'common', 'size',
    'tumor defines stage', 'median', 'survival', 'false', 'screened',
    'problem', 'target', 'reverses', 'dosing', 'measure', 'reduction',
})
# Whole-word match (plus simple inflections) so that 'men' no longer rejects
# "dementia"/"treatment" and 'age' no longer rejects "stage ia nsclc".
_EXCLUDE_RE = re.compile(
    r"\b(?:"
    + "|".join(
        re.escape(term)
        for term in sorted(_EXCLUDE_TERMS, key=lambda t: (-len(t), t))
    )
    + r")(?:s|es|d|ed|ing)?\b"
)

# Substrings that mark an entity as a condition/drug of interest (priority 2).
# The acronym "ra" is matched separately as a whole word via _RA_WORD_RE.
_CONDITION_KEYWORDS = (
    'diabetes', 'cancer', 'disease', 'disorder', 'syndrome', 'hypertension',
    'asthma', 'tuberculosis', 'alzheimer', 'migraine', 'hypothyroidism',
    'type 1', 'type 2', 'rheumatoid arthritis', 'osteoarthritis', 'warfarin',
    'methotrexate', 'inr', 'nsclc', 'lung cancer', 'stage ia', 'stage iv',
    'immunotherapy', 'pregnancy',
)
_QUESTION_WORDS = frozenset(
    {'what', 'how', 'when', 'where', 'which', 'who', 'why'}
)


def _capitalize_first(text):
    """Upper-case the first character of *text*; safe on the empty string."""
    return text[:1].upper() + text[1:]


def _has_digit(text):
    """Return True when *text* contains at least one digit."""
    return any(ch.isdigit() for ch in text)


def _is_negative(text):
    """Return True when *text* contains a whole-word negation marker."""
    return _NEGATION_RE.search(text) is not None


def _is_condition(text_lower):
    """Return True when a lower-cased entity text names a known condition."""
    if any(keyword in text_lower for keyword in _CONDITION_KEYWORDS):
        return True
    return _RA_WORD_RE.search(text_lower) is not None


def _answer_can(answer, q, subject):
    """Shape answers to 'Can ...' and 'Does ... cure ...' questions."""
    if 'cure' in q:
        if 'pregnancy' in q:
            return (
                f"No, pregnancy does not cure {subject.lower()}, "
                "though symptoms may temporarily improve."
            )
        if _is_negative(answer):
            return (
                f"No, {subject.lower()} cannot currently be cured, "
                "requiring lifelong management."
            )
        return f"Yes, {answer}."
    if 'used' in q and 'pregnancy' in q:
        if 'contraindicated' in answer.lower() or _is_negative(answer):
            return f"No, {subject} is contraindicated during pregnancy."
        return f"Yes, {subject} can be used during pregnancy."
    if answer.lower().startswith(('yes', 'no')):
        return answer
    return f"Yes, {answer}."


def _answer_do_does(answer, q, subject):
    """Shape answers to 'Do ...' / 'Does ...' questions."""
    lowered = answer.lower()
    if _ALL_QUANTIFIER_RE.search(q):
        if _is_negative(answer):
            return "No, not all patients show this characteristic."
        if '%' in answer or 'only' in lowered:
            return f"No, only {answer} of patients show this response."
        return f"No, {answer}."
    if 'differ' in q:
        return f"The key difference is that {subject.lower()} is {lowered}."
    if 'increase or decrease' in q:
        if lowered in ('increase', 'decrease'):
            return f"Antibiotics {lowered} warfarin effect."
        return f"{answer}."
    if '%' in answer or (len(answer.split()) <= 3 and _has_digit(answer)):
        if 'respond' in q:
            return f"No, only {answer} of patients respond to treatment."
        return f"Yes, approximately {answer}."
    if lowered in ('no', 'not', 'unclear', 'unknown'):
        return f"No, the exact cause is {lowered}."
    return _capitalize_first(answer)


def _answer_is(answer, q, subject):
    """Shape answers to 'Is ...?' questions."""
    word_count = len(answer.split())
    if 'more common' in q and ('men' in q or 'women' in q):
        gender = answer.lower()
        if gender in ('women', 'men'):
            other = 'men' if gender == 'women' else 'women'
            return f"{subject} is more common in {gender} than {other}."
        return f"{subject} affects {answer}."
    if 'specific' in q:
        if word_count < 8:
            return f"No, {subject.lower()} is not entirely specific."
        return _capitalize_first(answer)
    if word_count > 5:
        return _capitalize_first(answer)
    if 'chronic' in q:
        return f"Yes, {subject.lower()} is a chronic condition."
    return f"Yes, {answer}."


def _answer_how_does(answer, q):
    """Shape answers to 'How do ...' / 'How does ...' questions."""
    if 'differ' in q and len(answer.split()) < 6:
        return f"The main difference is that one is {answer.lower()}."
    return _capitalize_first(answer)


def _answer_how_much(answer, q, subject):
    """Shape answers to 'How much ...' / 'How many ...' questions."""
    if 'reduce' in q or 'life expectancy' in q:
        bare = (
            answer.replace('%', '')
            .replace('-', '')
            .replace('years', '')
            .strip()
            .replace(' ', '')
        )
        if bare.isdigit() or 'year' in answer:
            return (
                f"Untreated {subject.lower()} reduces life expectancy "
                f"by {answer}."
            )
        return f"It reduces mortality by {answer}."
    if 'reduction' in q:
        return (
            f"A dose reduction of {answer} is needed for certain "
            "genetic variants."
        )
    return f"The amount is {answer}."


def _answer_how_long(answer, q):
    """Shape answers to 'How long ...' / 'How fast ...' questions."""
    if 'stiffness' in q or 'last' in q:
        return f"Morning stiffness should last {answer} to suggest RA."
    if 'reverse' in q:
        return f"Vitamin K reverses warfarin in {answer}."
    return f"The duration is {answer}."


def _answer_how_often(answer, q):
    """Shape answers to 'How often ...' / 'How frequently ...' questions."""
    lowered = answer.lower()
    if any(cue in lowered for cue in ('checked', 'monitored', 'should be done')):
        return _capitalize_first(answer)
    if 'inr' in q:
        return f"INR should be monitored {answer}."
    return f"The frequency is {answer}."


def _answer_how_common(answer):
    """Shape answers to '... how common ...' questions."""
    if '%' in answer or _has_digit(answer):
        # Collapse a phrase the model is known to repeat.
        answer = answer.replace(
            'of patients per year of patients per year',
            'of patients per year',
        )
        return f"The incidence is {answer}."
    return f"The frequency is {answer}."


def _answer_what_age(answer, q, subject):
    """Shape answers to '... what age ...' questions."""
    # Whole-word test: a plain substring test for 'ra' also fires on
    # "migraine", "rate", "average", ... and mislabels the subject as RA.
    if _RA_WORD_RE.search(q):
        subject = 'RA'
    has_range = (
        'between' in answer
        or 'ages of' in answer
        or ('-' in answer and _has_digit(answer))
    )
    if has_range:
        span = answer.replace('between ages', '').strip()
        return f"{subject} typically begins between ages {span}."
    if _has_digit(answer):
        return f"{subject} typically begins at {answer}."
    return f"The typical age is {answer}."


def _answer_when(answer, q):
    """Shape answers to 'When ...' questions."""
    mentions_this_occurs = 'this occurs' in answer.lower()
    if 'begin' in q or 'start' in q:
        if mentions_this_occurs:
            return answer.replace(
                'This occurs', 'Treatment should begin within'
            ).replace('this occurs', 'within')
        if _has_digit(answer):
            return f"Treatment should begin within {answer} of symptom onset."
        return f"Treatment should begin {answer}."
    if 'used' in q:
        capital, lower, template = 'They are used for', 'for', "They are used for {}."
    elif 'pcc' in q or 'reversal' in q:
        capital, lower, template = 'PCC is used for', 'for', "PCC is used for {}."
    else:
        capital, lower, template = 'This happens at', 'at', "This occurs {}."
    if mentions_this_occurs:
        return answer.replace('This occurs', capital).replace('this occurs', lower)
    return template.format(answer)


def _answer_rate(answer, q):
    """Shape answers to percentage / rate / risk questions."""
    numeric = '%' in answer or (
        answer.replace('.', '').replace('-', '').strip().isdigit()
    )
    if not numeric:
        return f"The percentage is {answer}."
    if 'survival rate' in q:
        return f"The survival rate is {answer}."
    if 'remission' in q:
        return f"The remission rate is {answer} with early treatment."
    if 'false positive' in q:
        return f"The false positive rate is {answer}."
    if 'risk' in q:
        return f"The annual risk is {answer}."
    if 'negative' in q:
        return f"Approximately {answer} of patients test negative."
    if 'positive' in q:
        return f"Approximately {answer} of patients test positive."
    if 'respond' in q:
        return f"Approximately {answer} of patients respond."
    return f"The percentage is {answer}."


def _answer_size_or_median(answer, q):
    """Shape answers to 'what size' / 'what is the median' questions."""
    if 'size' in q:
        return f"Stage IA NSCLC is defined as tumors ≤{answer}."
    return f"The median survival is {answer} with immunotherapy."


def _answer_what_is(answer, q):
    """Shape answers to 'What is ...' / 'What are ...' questions."""
    word_count = len(answer.split())
    capitalized = _capitalize_first(answer)

    if q.startswith(('what is the therapeutic', 'what is seronegative')):
        bare = (
            answer.replace('.', '')
            .replace('-', '')
            .replace('/', '')
            .replace(' ', '')
            .replace('%', '')
        )
        if not (bare.isdigit() or word_count < 4):
            return capitalized
        if 'therapeutic window' in q:
            return (
                "The therapeutic window is the narrow range between "
                "effective and toxic doses."
            )
        if 'seronegative' in q:
            return (
                "Seronegative RA refers to cases where patients test "
                "negative for rheumatoid factor."
            )
        return f"It is defined as {answer}."

    if 'extra-articular' in q or 'manifestations' in q:
        if word_count < 6:
            return (
                "Extra-articular manifestations are symptoms affecting "
                "the lungs, heart, or eyes."
            )
        return capitalized

    if 'measure' in q:
        if word_count >= 4:
            return capitalized
        if 'tnm' in q:
            return (
                "The TNM system measures tumor size (T), lymph node "
                "involvement (N), and metastasis (M)."
            )
        if 'inr' in q:
            return (
                "INR measures the blood's clotting time and therapeutic "
                "effect of warfarin."
            )
        return f"It measures {answer}."

    if 'reverse' in q:
        if word_count >= 4:
            return capitalized
        if 'immediately' in q:
            return (
                f"{answer} reverses warfarin immediately but has a "
                "short duration."
            )
        return f"{answer} reverses warfarin."

    if 'first-line' in q or 'dmards' in q:
        if word_count < 3:
            return f"The first-line DMARD is {answer}."
        return f"The first-line treatments include {answer}."

    if 'lab test' in q or 'tests' in q:
        if word_count > 10:
            return capitalized
        return f"The tests include {capitalized}."

    if 'happen' in q:
        if word_count < 6:
            return f"During pregnancy, {answer}."
        return capitalized

    if 'instead' in q or 'alternative' in q:
        if word_count < 4:
            return f"The alternative is low-molecular-weight {answer}."
        return f"{answer} is used as an alternative."

    if 'problem' in q:
        return f"The problem is {answer.lower()}."

    if 'target' in q and 'inr' in q:
        return f"The target INR range is {answer}."

    return capitalized


def _answer_who(answer, q):
    """Shape answers to 'Who ...' questions."""
    if 'screened' in q and len(answer.split()) < 4:
        return (
            "High-risk individuals aged 50-80 with 30+ pack-year smoking "
            "history should be screened."
        )
    return _capitalize_first(answer)


def _answer_why(answer, q):
    """Shape answers to 'Why ...' questions."""
    if 'avoided' in q or 'dangerous' in q:
        if len(answer.split()) >= 5:
            return _capitalize_first(answer)
        if 'pregnancy' in q:
            return f"Warfarin is avoided in pregnancy {answer.lower()}."
        if 'nsaid' in q:
            return (
                "NSAIDs are dangerous with warfarin because they "
                f"{answer.lower()}."
            )
    return f"This is because {answer}."


class PointerGeneratorT5(nn.Module):
    """T5 wrapped with a pointer-generator gate that can copy source tokens."""

    def __init__(self, model_name='t5-base'):
        """Load the pretrained T5 checkpoint and create the p_gen projection.

        Raises:
            ImportError: if the ``transformers`` package is not installed.
        """
        super().__init__()
        if T5ForConditionalGeneration is None:
            raise ImportError(
                "PointerGeneratorT5 requires the 'transformers' package."
            )
        self.t5 = T5ForConditionalGeneration.from_pretrained(model_name)
        self.config = self.t5.config

        # Pointer-generator gate: [context vector ; decoder state] -> scalar.
        self.p_gen_linear = nn.Linear(self.config.d_model * 2, 1)

    def forward(self, input_ids, attention_mask, decoder_input_ids=None,
                encoder_outputs=None):
        """Run T5 and return hidden states and attentions alongside logits.

        Args:
            input_ids: source token ids.
            attention_mask: mask over ``input_ids``.
            decoder_input_ids: decoder tokens generated so far.
            encoder_outputs: optional pre-computed encoder outputs; when
                given, they are handed to T5 so the encoder is not re-run.
        """
        return self.t5(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            output_hidden_states=True,
            output_attentions=True,
            return_dict=True,
        )

    @staticmethod
    def _pointer_distribution(attention_weights, input_ids, vocab_size):
        """Scatter source attention mass onto the vocabulary ids it points at.

        Repeated source tokens accumulate their attention weights.
        """
        pointer_dist = attention_weights.new_zeros(
            attention_weights.size(0), vocab_size
        )
        pointer_dist.scatter_add_(1, input_ids, attention_weights)
        return pointer_dist

    def generate_with_pointer(
        self,
        input_ids,
        attention_mask,
        tokenizer,
        max_length=100,
        temperature=0.7,
    ):
        """Greedily decode one sequence, mixing vocabulary and copy distributions.

        Args:
            input_ids: source token ids; exactly one sequence.
            attention_mask: mask over ``input_ids``.
            tokenizer: unused; kept so existing callers keep working.
            max_length: maximum number of decoding steps.
            temperature: softmax temperature for the vocabulary distribution.

        Returns:
            ``(token_ids, p_gen)``: the generated ids (EOS excluded) and the
            generation probability of the last decoding step, or ``nan``
            when no step was taken.

        Raises:
            ValueError: if the batch holds more than one sequence or
                ``temperature`` is not positive.
        """
        if input_ids.size(0) != 1:
            raise ValueError(
                "generate_with_pointer decodes a single sequence; got a "
                f"batch of {input_ids.size(0)}."
            )
        if temperature <= 0:
            raise ValueError("temperature must be positive.")

        eos_token_id = self.t5.config.eos_token_id
        # Start with the decoder start token.
        decoder_input_ids = torch.full(
            (1, 1),
            self.t5.config.decoder_start_token_id,
            dtype=torch.long,
            device=input_ids.device,
        )

        generated_tokens = []
        p_gen_value = float('nan')
        encoder_outputs = None

        for _ in range(max_length):
            outputs = self.forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                encoder_outputs=encoder_outputs,
            )

            encoder_hidden = outputs.encoder_last_hidden_state  # [batch, seq, hidden]
            if encoder_outputs is None:
                # The source never changes, so encode it only once.
                encoder_outputs = (encoder_hidden,)

            logits = outputs.logits[:, -1, :]  # [batch, vocab]
            # Last decoder layer, last generated position.
            decoder_hidden = outputs.decoder_hidden_states[-1][:, -1, :]

            # Last-layer cross-attention for the newest token, averaged over
            # heads: [batch, enc_len].
            attention_weights = outputs.cross_attentions[-1][:, :, -1, :].mean(dim=1)

            # Attention-weighted sum of encoder states: [batch, hidden].
            context_vector = torch.bmm(
                attention_weights.unsqueeze(1), encoder_hidden
            ).squeeze(1)

            # p_gen: probability of generating rather than copying, [batch, 1].
            p_gen = torch.sigmoid(
                self.p_gen_linear(
                    torch.cat([context_vector, decoder_hidden], dim=-1)
                )
            )
            p_gen_value = p_gen.item()

            vocab_dist = torch.softmax(logits / temperature, dim=-1)
            pointer_dist = self._pointer_distribution(
                attention_weights, input_ids, vocab_dist.size(-1)
            )
            final_dist = p_gen * vocab_dist + (1 - p_gen) * pointer_dist

            # Greedy choice over the mixed distribution.
            next_token = torch.argmax(final_dist, dim=-1)
            token_id = next_token.item()
            if token_id == eos_token_id:
                break

            generated_tokens.append(token_id)
            decoder_input_ids = torch.cat(
                [decoder_input_ids, next_token.unsqueeze(-1)], dim=-1
            )

        return generated_tokens, p_gen_value


class MedicalQAProcessor:
    """Generate answers with the pointer model and shape them into sentences."""

    def __init__(self, model, tokenizer, device, nlp, medical_terms=None):
        """Store the collaborators.

        Args:
            model: object exposing ``generate_with_pointer``.
            tokenizer: callable tokenizer that also provides ``decode``.
            device: device the tokenized inputs are moved to.
            nlp: callable returning a document with ``ents`` and
                ``noun_chunks``.
            medical_terms: optional fallback terms searched in the question.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.nlp = nlp
        self.medical_terms = medical_terms or set()

    def generate_answer(self, question, context, max_length=100,
                        use_sentence_structure=True):
        """Answer *question* from *context* and report the last p_gen score.

        Returns:
            dict with keys ``'answer'``, ``'p_gen_score'`` (formatted to
            three decimals) and ``'interpretation'``.
        """
        if use_sentence_structure:
            input_text = (
                "answer in complete sentence. "
                f"question: {question} context: {context}"
            )
        else:
            input_text = f"question: {question} context: {context}"

        inputs = self.tokenizer(
            input_text,
            max_length=512,
            truncation=True,
            return_tensors='pt',
        ).to(self.device)

        with torch.no_grad():
            generated_ids, p_gen_score = self.model.generate_with_pointer(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                tokenizer=self.tokenizer,
                max_length=max_length,
                temperature=0.7,
            )

        answer = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        if use_sentence_structure and answer:
            answer = self.ensure_sentence_structure(answer, question)

        return {
            'answer': answer,
            'p_gen_score': f"{p_gen_score:.3f}",
            'interpretation': 'Higher p_gen = more generation, Lower = more copying',
        }

    def extract_subject_umls(self, question):
        """Pick the most likely subject of *question*, title-cased.

        Candidates are tried in order: named entities (condition keywords
        first, then multi-word ``ENTITY`` spans, then the rest, earliest
        first), the configured medical terms (longest first, so the result
        does not depend on set ordering), and finally short noun chunks.
        Returns ``"It"`` when nothing qualifies.
        """
        doc = self.nlp(question)
        question_lower = question.lower()

        candidates = []
        for ent in doc.ents:
            text_lower = ent.text.lower()
            if _EXCLUDE_RE.search(text_lower):
                continue
            if _is_condition(text_lower):
                priority = 2
            elif ent.label_ == 'ENTITY' and len(ent.text.split()) > 1:
                priority = 1
            else:
                priority = 0
            candidates.append((ent.text, priority, ent.start_char))

        if candidates:
            # Highest priority first, then earliest position in the question.
            best = min(candidates, key=lambda item: (-item[1], item[2]))
            return best[0].title()

        for term in sorted(self.medical_terms, key=lambda t: (-len(t), t)):
            if term.lower() in question_lower:
                return term.title()

        for chunk in doc.noun_chunks:
            chunk_text = chunk.text
            chunk_lower = chunk_text.lower()
            if chunk_lower in _EXCLUDE_TERMS or chunk_lower in _QUESTION_WORDS:
                continue
            if len(chunk_text.split()) <= 4:
                return chunk_text.title()

        return "It"

    def ensure_sentence_structure(self, answer, question):
        """Wrap a raw model answer into a complete sentence for *question*.

        The question's leading words select a template family; the first
        matching family wins, in the order of the chain below.  An empty
        answer is returned as the empty string, and an answer that already
        looks like a full sentence (more than eight words, capitalised,
        ending in ``.``, ``!`` or ``?``) is returned unchanged.
        """
        answer = answer.strip()
        if not answer:
            return answer

        # If already well-formed
        if len(answer.split()) > 8 and answer[0].isupper() and answer[-1] in '.!?':
            return answer

        # Templates add their own period; dropping the answer's trailing one
        # avoids results such as "Yes, it helps..".
        core = answer.rstrip('. ')
        if not core:
            return answer

        q = question.lower()
        subject = self.extract_subject_umls(question)

        if q.startswith('can ') or (q.startswith('does ') and 'cure' in q):
            result = _answer_can(core, q, subject)
        elif q.startswith(('do ', 'does ')):
            result = _answer_do_does(core, q, subject)
        elif q.startswith('is ') and '?' in question:
            result = _answer_is(core, q, subject)
        elif q.startswith('how do'):  # also covers 'how does'
            result = _answer_how_does(core, q)
        elif q.startswith(('how much', 'how many')):
            result = _answer_how_much(core, q, subject)
        elif q.startswith(('how long', 'how fast')):
            result = _answer_how_long(core, q)
        elif q.startswith(('how often', 'how frequently')):
            result = _answer_how_often(core, q)
        elif 'how common' in q:
            result = _answer_how_common(core)
        elif 'what age' in q:  # also covers 'at what age'
            result = _answer_what_age(core, q, subject)
        elif q.startswith('when '):
            result = _answer_when(core, q)
        elif (
            'what percentage' in q
            or 'what is the survival rate' in q
            or 'what is the false positive rate' in q
            or 'what remission rate' in q
            or 'what is the annual risk' in q
        ):
            result = _answer_rate(core, q)
        elif 'what size' in q or 'what is the median' in q:
            result = _answer_size_or_median(core, q)
        elif q.startswith(('what is', 'what are')):
            result = _answer_what_is(core, q)
        elif q.startswith('who '):
            result = _answer_who(core, q)
        elif q.startswith('why '):
            result = _answer_why(core, q)
        else:
            # Fallback; 'should ...' questions are handled the same way.
            result = _capitalize_first(core)

        # Final check: every result ends with sentence punctuation.
        if not result.endswith(('.', '!', '?')):
            result += '.'
        return result