File size: 4,245 Bytes
b1a4d14
 
 
0979c15
b1a4d14
 
 
 
 
 
 
de03d0d
b1a4d14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbce739
574c0fa
de03d0d
cbce739
b1a4d14
 
 
 
 
 
de03d0d
e12e0d3
b1a4d14
de03d0d
b1a4d14
 
bb220ac
e12e0d3
4da3949
 
574c0fa
e12e0d3
0979c15
 
b1a4d14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0979c15
e12e0d3
b1a4d14
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
File: app.py
Usage: Hugging Face Spaces Deployment
Description: Domain-Specific Assistant via LLMs Fine-Tuning.
"""

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face Hub repository containing the fine-tuned model
MODEL_REPO = "degide/tinyllama-medical-assistant"

print("Downloading and loading the fine-tuned medical chatbot...")

# 1. Load the Tokenizer and Model directly from Hugging Face Hub
# (trust_remote_code allows repo-provided tokenizer/model code to run).
tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO, trust_remote_code=True)

# Configuring the model for efficient CPU loading.
# float32 + low_cpu_mem_usage keeps the Space within CPU/RAM limits;
# device_map="cpu" pins all weights to the CPU (no GPU on free Spaces tiers).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_REPO,
    device_map="cpu",
    trust_remote_code=True,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
)
# Inference mode: disables dropout and other training-only behavior.
model.eval()

print("Model loaded successfully!")

def detect_ood(query):
    """Return True when *query* looks out-of-domain for a medical assistant.

    Lightweight substring heuristic: a query is out-of-domain if it matches
    any known non-medical topic, or if it mentions no medical vocabulary
    at all. Matching is case-insensitive.
    """
    text = query.lower()

    in_domain_terms = (
        'symptom', 'disease', 'treatment', 'medicine', 'doctor', 'health',
        'diabetes', 'blood', 'pressure', 'heart', 'pain', 'sick', 'hospital',
        'care', 'diagnosis', 'patient', 'clinic', 'drug', 'therapy', 'cancer',
        'syndrome', 'infection', 'virus', 'bacteria', 'pill', 'dosage'
    )
    out_of_domain_terms = (
        'cook', 'recipe', 'weather', 'capital', 'python', 'code',
        'movie', 'song', 'game', 'sports', 'programming', 'math'
    )

    mentions_medical = any(term in text for term in in_domain_terms)
    mentions_other_topic = any(term in text for term in out_of_domain_terms)

    # Any off-topic hit wins; otherwise require at least one medical term.
    return mentions_other_topic or not mentions_medical

def generate_medical_response(message, history):
    """Generate the chatbot reply for *message*, refusing off-topic queries.

    *history* is supplied by gr.ChatInterface and is unused here; each turn
    is answered from the current message alone. Returns a markdown string.
    """
    # Guard clause: politely decline anything the heuristic flags as OOD.
    if detect_ood(message):
        return (
            "**Out of Domain Detected:** I apologize, but I am a specialized medical "
            "assistant and can only answer health-related questions. Could you please "
            "ask me about medical symptoms, conditions, or treatments?\n\n"
            "*Examples:*\n"
            "- What are the symptoms of asthma?\n"
            "- How is high blood pressure diagnosed?"
        )

    # NOTE(review): TinyLlama chat templates usually end the prompt with an
    # "<|assistant|>" tag to cue the reply (the split below expects the model
    # to emit it) — confirm this matches the fine-tuning format.
    prompt = (
        f"</s><|system|>You are a highly accurate and helpful medical assistant.</s>"
        f"<|user|>{message}</s>"
    )
    encoded = tokenizer(prompt, return_tensors="pt")

    # Gradients are never needed at inference time; sampling at a low
    # temperature keeps answers focused but not fully deterministic.
    with torch.no_grad():
        output_ids = model.generate(
            **encoded,
            max_new_tokens=256,
            temperature=0.3,
            repetition_penalty=1.0,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    print(f"generated_text: {generated_text}")

    # Keep only the text after the last assistant tag (the prompt echo and
    # any earlier tags are discarded).
    final_answer = generated_text.split("<|assistant|>")[-1].strip()

    if not final_answer:
        final_answer = "I apologize, but I am unable to generate a confident medical response to that exact phrasing. Could you please rephrase your question?"

    disclaimer = (
        "\n\n---\n"
        "**Medical Disclaimer:** *This chatbot provides general health information "
        "only based on fine-tuned data. It is not a replacement for professional "
        "medical advice. Always consult a qualified healthcare provider.*"
    )

    return final_answer + disclaimer

# --- USER INTERFACE ---
_DESCRIPTION = (
    "An LLM fine-tuned via LoRA on the Medical Meadow Flashcards dataset. "
    "Ask questions about medical symptoms, conditions, and treatments."
)
_EXAMPLE_QUERIES = [
    "What are the common symptoms of type 2 diabetes?",
    "Explain the mechanism of action of metformin.",
    "What is the prognosis for patients with stage 3 chronic kidney disease?",
    "Describe the side effects of chemotherapy for breast cancer."
]

# Chat UI wired to the generation callback defined above.
demo = gr.ChatInterface(
    fn=generate_medical_response,
    title="Domain-Specific Medical Assistant (TinyLlama)",
    description=_DESCRIPTION,
    examples=_EXAMPLE_QUERIES,
    chatbot=gr.Chatbot(height=600),
    save_history=True,
    fill_height=True,
    fill_width=True,
    submit_btn="Ask",
    stop_btn="Stop"
)

if __name__ == "__main__":
    # Serve locally on the Spaces default port; no public share link needed
    # since the Space itself provides the public URL.
    demo.launch(
        share=False,
        server_port=7860,
    )