import chromadb
import traceback
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from retriever import retrieve
from utils import build_prompt, refine_response


# ============================================================================
# LAZY-LOAD MODELS AND VECTOR STORE (load on first use, not at import)
# ============================================================================

_vector_store = None
_finetuned_llm = None
_base_model = None
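# NOTE: _base_model is reserved here but never loaded in this module.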

def get_vector_store():
    """Load vector store (lazy-loaded on first use)"""
    global _vector_store
    if _vector_store is None:
        db_client = chromadb.PersistentClient(path="./MedQuAD_db")
        try:
            _vector_store = db_client.get_collection("medical_rag")
        except Exception:
            # If collection doesn't exist, create it
            _vector_store = db_client.create_collection(name="medical_rag")
    return _vector_store

def get_finetuned_llm():
    """Load fine-tuned model (lazy-loaded on first use)"""
    global _finetuned_llm
    if _finetuned_llm is None:
        ft_model_id = "amiraghhh/fine-tuned-flan-t5-small"
        ft_tokenizer = AutoTokenizer.from_pretrained(ft_model_id)
        ft_model = AutoModelForSeq2SeqLM.from_pretrained(ft_model_id)
        
        _finetuned_llm = pipeline(
            "text2text-generation",
            model=ft_model,
            tokenizer=ft_tokenizer,
            decoder_start_token_id=ft_model.config.pad_token_id
        )
    return _finetuned_llm


# ============================================================================
# MAIN RAG FUNCTION
# ============================================================================

def rag(user_query):
    """Main RAG function: retrieve context and generate answer.
    Takes a question string and returns an answer string with confidence.
    Returns: str(generated_answer)"""
    
    try:
        # Load models on first use
        vector_store = get_vector_store()
        finetuned_llm = get_finetuned_llm()
        
        # 1. Check for emergency keywords
        emergency_keywords = ["emergency", "severe pain", "bleeding",
                            "blind", "lose consciousness", "pass out"]
        
        if any(keyword in user_query.lower() for keyword in emergency_keywords):
            emergency_msg = """I am an AI and cannot provide medical advice for emergencies.
PLEASE contact emergency services or a medical professional immediately."""
            
            try:
                # Still generate answer for context
                contexts = retrieve(vector_store, user_query, top_k=3, use_reranking=True)
                
                if not contexts:
                    return f"{emergency_msg}\n\nNo relevant information found for your query."
                
                prompt = build_prompt(user_query, contexts)
                result = finetuned_llm(
                    prompt,
                    max_new_tokens=70,
                    num_beams=3,
                    early_stopping=True,
                    do_sample=False,
                    repetition_penalty=1.4,
                    eos_token_id=finetuned_llm.tokenizer.eos_token_id
                )
                
                answer = result[0]['generated_text'].strip()
                answer = refine_response(answer)
                
                # Calculate confidence
                if contexts:
                    avg_distance = sum(c.get('chroma_distance', 1.0) for c in contexts) / len(contexts)
                    confidence_score = (1 - avg_distance) * 100
                    confidence_score = max(0, min(100, confidence_score))
                else:
                    confidence_score = 0
                
                return f"{emergency_msg}\n\n[Confidence: {confidence_score:.1f}%]\n\n{answer}"
            
            except Exception as e:
                return f"{emergency_msg}\n\nError generating answer: {str(e)}"
        
        # 2. Retrieve relevant contexts
        contexts = retrieve(vector_store, user_query, top_k=3, use_reranking=True)
        
        if not contexts:
            return "I'm not confident about my answer (0%).\n\nCouldn't find relevant information to answer your question."
        
        # 3. Build prompt with context
        prompt = build_prompt(user_query, contexts)
        
        # 4. Generate answer
        result = finetuned_llm(
            prompt,
            max_new_tokens=70,
            num_beams=3,
            early_stopping=True,
            do_sample=False,
            repetition_penalty=1.4,
            eos_token_id=finetuned_llm.tokenizer.eos_token_id
        )
        
        answer = result[0]['generated_text'].strip()
        answer = refine_response(answer)
        
        # 5. Calculate confidence score based on retrieval quality
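        # (lower average chroma distance -> higher confidence; the result is
        # clamped to 0-100 in case distances fall outside the [0, 1] range)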
        if contexts:
            avg_distance = sum(c.get('chroma_distance', 1.0) for c in contexts) / len(contexts)
            confidence_score = (1 - avg_distance) * 100
            confidence_score = max(0, min(100, confidence_score))
            
            # Build final response with confidence
            if confidence_score < 40:
                final_response = f"I'm not confident about my answer ({confidence_score:.1f}%).\n\n{answer}"
            else:
                final_response = f"{answer}\n\n[Confidence: {confidence_score:.1f}%]"
        else:
            final_response = "I'm not confident about my answer (0%).\n\n" + answer
        
        return final_response
    
    except Exception as e:
        error_msg = f"ERROR in RAG pipeline: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)
        return error_msg
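

# ============================================================================
# EXAMPLE USAGE (illustrative sketch: the sample question is hypothetical and
# assumes the ./MedQuAD_db vector store has already been built)
# ============================================================================

if __name__ == "__main__":
    sample_question = "What are the symptoms of glaucoma?"
    print(rag(sample_question))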