|
|
import streamlit as st
|
|
|
import torch
|
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
|
|
from peft import PeftModel
|
|
|
import os
|
|
|
|
|
|
|
|
|
# --- Page configuration and global CSS ---------------------------------------
# Must run before any other Streamlit call in the script.
st.set_page_config(
    page_title="TinyLlama Medical Assistant",
    page_icon="🩺",  # restored from mojibake "π©Ί" in the original source
    layout="wide"
)

# Custom styling: full-width indigo buttons and chat "bubble" cards.
st.markdown("""
<style>
.main {background-color: #f0f2f6;}
.stButton>button {
    width: 100%;
    background-color: #4F46E5;
    color: white;
}
.chat-message {
    padding: 1.5rem;
    border-radius: 0.5rem;
    margin-bottom: 1rem;
    display: flex;
    flex-direction: column;
}
.chat-message.user {
    background-color: #4F46E5;
    color: white;
}
.chat-message.assistant {
    background-color: white;
    border: 1px solid #e5e7eb;
}
</style>
""", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
# Demo credential store mapping username -> password.
# NOTE(security): plaintext passwords hard-coded in source are only
# acceptable for a demo; use hashed credentials and a secrets store
# (e.g. st.secrets) for anything real.
USERS = {
    "admin": "admin123",
    "doctor": "doc123",
    "student": "student123",
}

# Appended to every generated answer so users never mistake model output
# for professional medical advice.
MEDICAL_DISCLAIMER = """
⚠️ **Medical Disclaimer:** This response is for educational purposes only and is not a substitute for professional medical advice. Always consult a qualified healthcare provider.
"""
|
|
|
|
|
|
|
|
|
if "authenticated" not in st.session_state:
|
|
|
st.session_state.authenticated = False
|
|
|
if "messages" not in st.session_state:
|
|
|
st.session_state.messages = []
|
|
|
if "model_loaded" not in st.session_state:
|
|
|
st.session_state.model_loaded = False
|
|
|
|
|
|
|
|
|
# --- Login gate --------------------------------------------------------------
# Renders the login form and halts the script (st.stop) for unauthenticated
# sessions; everything below only runs once authenticated.
if not st.session_state.authenticated:
    col1, col2, col3 = st.columns([1, 2, 1])

    with col2:
        st.title("🔐 Medical Assistant Login")  # emoji restored from mojibake
        st.markdown("---")

        username = st.text_input("Username", key="login_username")
        password = st.text_input("Password", type="password", key="login_password")

        col_a, col_b = st.columns(2)

        with col_a:
            if st.button("Login", use_container_width=True):
                # Single dict lookup; .get returns None for unknown users,
                # which can never equal a text_input string.
                # NOTE(security): plaintext comparison against the demo USERS
                # dict — replace with hashed verification in production.
                if USERS.get(username) == password:
                    st.session_state.authenticated = True
                    st.session_state.username = username
                    # String was split/garbled in the original source;
                    # reconstructed here.
                    st.success("✅ Login successful!")
                    st.rerun()
                else:
                    st.error("❌ Invalid credentials")

        with col_b:
            # Rerun clears any visible error/success messages.
            if st.button("Clear", use_container_width=True):
                st.rerun()

        st.markdown("---")
        st.info("""
**Demo Credentials:**
- admin / admin123
- doctor / doc123
- student / student123
""")

    # Halt script execution for unauthenticated sessions.
    st.stop()
|
|
|
|
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def load_model():
    """Load the 4-bit quantized TinyLlama model, plus LoRA adapters if present.

    Falls back to the plain base model when the adapter directory is missing.
    Cached by st.cache_resource for the lifetime of the Streamlit server.

    Returns:
        tuple: (tokenizer, model) on success, or (None, None) on failure.
    """
    try:
        base_model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        lora_path = "./tinyllama-medical-lora"

        # Adapter dir missing -> warn and continue with the base model.
        if not os.path.exists(lora_path):
            st.error(f"❌ Model not found at {lora_path}")
            st.info("Using base model without fine-tuning...")
            lora_path = None

        # NF4 double-quantized 4-bit weights keep the 1.1B model within
        # modest GPU memory; compute runs in bfloat16.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

        tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        # TinyLlama's tokenizer ships without a pad token; reuse EOS so
        # generate() can pad batches.
        tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
        )

        if lora_path:
            model = PeftModel.from_pretrained(model, lora_path)
            # String was split/garbled in the original source; reconstructed.
            st.success("✅ Fine-tuned model loaded successfully!")
        else:
            st.warning("⚠️ Using base model (not fine-tuned)")

        model.eval()  # inference mode: disables dropout etc.
        return tokenizer, model

    except Exception as e:
        # UI boundary: surface the failure instead of crashing the app.
        st.error(f"Error loading model: {str(e)}")
        return None, None
|
|
|
|
|
|
|
|
|
st.title("π©Ί TinyLlama Medical Assistant")
|
|
|
st.caption(f"Logged in as: **{st.session_state.username}**")
|
|
|
|
|
|
|
|
|
# --- Sidebar: model loading, generation parameters, shortcuts ----------------
with st.sidebar:
    st.header("⚙️ Settings")  # emoji restored from mojibake

    # Load the model once per session; @st.cache_resource dedupes the
    # expensive load across sessions on the same server.
    if not st.session_state.model_loaded:
        with st.spinner("Loading fine-tuned model..."):
            tokenizer, model = load_model()
            if tokenizer and model:
                st.session_state.tokenizer = tokenizer
                st.session_state.model = model
                st.session_state.model_loaded = True

    st.markdown("---")

    # Sliders: (label, min, max, default, step). Values are read by the
    # generation block at the bottom of the script.
    st.subheader("Generation Parameters")
    temperature = st.slider("Temperature", 0.1, 1.5, 0.7, 0.1)
    max_tokens = st.slider("Max New Tokens", 32, 256, 100, 8)
    top_p = st.slider("Top-p", 0.1, 1.0, 0.9, 0.05)

    st.markdown("---")

    st.subheader("💡 Example Queries")
    example_queries = [
        "What is Paracetamol used for?",
        "Tell me about Ibuprofen",
        "What is Metformin?",
        "Uses of Amoxicillin",
        "What is Atorvastatin for?"
    ]

    for query in example_queries:
        if st.button(query, key=f"example_{query}", use_container_width=True):
            # NOTE(review): this only queues the user message and reruns;
            # generation happens solely in the chat_input branch, so the
            # assistant reply to an example query does not appear until the
            # user submits something via the chat box — confirm intended.
            st.session_state.messages.append({"role": "user", "content": query})
            st.rerun()

    st.markdown("---")

    if st.button("🗑️ Clear Chat", use_container_width=True):
        st.session_state.messages = []
        st.rerun()

    if st.button("🚪 Logout", use_container_width=True):
        st.session_state.authenticated = False
        st.session_state.messages = []
        st.rerun()
|
|
|
|
|
|
|
|
|
# Replay the persisted conversation so it survives Streamlit reruns.
for _msg in st.session_state.messages:
    with st.chat_message(_msg["role"]):
        st.markdown(_msg["content"])
|
|
|
|
|
|
|
|
|
# --- Chat input and response generation --------------------------------------
if prompt := st.chat_input("Ask a medical question..."):

    # Echo and persist the user's turn.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            if st.session_state.model_loaded:
                try:
                    # Alpaca-style template matching the LoRA fine-tuning
                    # format.
                    formatted_prompt = f"""### Instruction:
{prompt}

### Response:
"""

                    inputs = st.session_state.tokenizer(
                        formatted_prompt,
                        return_tensors="pt"
                    ).to(st.session_state.model.device)

                    with torch.no_grad():
                        outputs = st.session_state.model.generate(
                            **inputs,
                            max_new_tokens=max_tokens,
                            temperature=temperature,
                            top_p=top_p,
                            do_sample=True,
                            # Silences the "pad_token_id not set" warning.
                            pad_token_id=st.session_state.tokenizer.eos_token_id
                        )

                    # Decode only the newly generated tokens so the prompt is
                    # never echoed back, even when the "### Response:" marker
                    # is altered by tokenization round-tripping.
                    prompt_len = inputs["input_ids"].shape[1]
                    response = st.session_state.tokenizer.decode(
                        outputs[0][prompt_len:],
                        skip_special_tokens=True
                    )

                    # Defensive fallback: strip any residual marker text.
                    if "### Response:" in response:
                        response = response.split("### Response:")[-1]
                    response = response.strip()

                    full_response = f"{response}\n\n{MEDICAL_DISCLAIMER}"

                    st.markdown(full_response)
                    st.session_state.messages.append({
                        "role": "assistant",
                        "content": full_response
                    })

                except Exception as e:
                    # UI boundary: show the error and keep it in history so
                    # the failed turn stays visible across reruns.
                    error_msg = f"Error generating response: {str(e)}"
                    st.error(error_msg)
                    st.session_state.messages.append({
                        "role": "assistant",
                        "content": error_msg
                    })
            else:
                error_msg = "Model not loaded. Please refresh the page."
                st.error(error_msg)
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": error_msg
                })
|
|
|
|
|
|
|
|
|
st.markdown("---")
|
|
|
st.caption("Fine-tuned TinyLlama 1.1B with LoRA on Allopathic Medicine Dataset") |