import pickle
import threading
import warnings

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
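
# Streamlit app: a simple healthcare Q&A chatbot backed by a local FLAN-T5 model.
# Launch with `streamlit run <script>.py`.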

# Silence noisy torch UserWarnings in the Streamlit console.
warnings.filterwarnings("ignore", category=UserWarning, module="torch")

# Directory containing the locally saved FLAN-T5 checkpoint.
MODEL_DIR = "flan-t5-small"


@st.cache_resource(show_spinner=False)
def load_model():
    """Load the FLAN-T5 tokenizer and model from a local directory (cached across reruns)."""
    st.write("🔄 Loading FLAN-T5 model from local storage...")
    try:
        # Half precision only makes sense on a GPU; CPUs need float32.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        torch_dtype = torch.float16 if device == "cuda" else torch.float32
        tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, local_files_only=True)
        model = AutoModelForSeq2SeqLM.from_pretrained(
            MODEL_DIR,
            torch_dtype=torch_dtype,
            local_files_only=True
        ).to(device)
        st.write("✅ Model loaded successfully from local storage!")
        return tokenizer, model
    except Exception as e:
        st.error(f"❌ Model failed to load: {e}")
        st.stop()


st.title("🩺 Healthcare Chatbot (FLAN-T5)")

# load_model() reports its own errors and halts the app via st.stop() on failure.
tokenizer, model = load_model()
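
# Expected pickle format: a list of {"question": ..., "answer": ...} dicts,
# matching the lookups in chatbot_response below.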
medical_qna = []
data_loaded = threading.Event()


def load_pickle_data():
    """Populate medical_qna from the pickle file and signal completion."""
    global medical_qna
    try:
        with open("train_data_mod_obfuscated_fixed.pkl", "rb") as file:
            medical_qna = pickle.load(file)
    except Exception:
        # Leave the list empty; chatbot_response reports the missing data.
        medical_qna = []
    finally:
        data_loaded.set()
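
# NOTE: Streamlit reruns this script on every interaction, so the loader thread
# (and the pickle read) restarts on each run. A cached loader would avoid the
# repeated work; a minimal sketch, assuming the same pickle path:
#
#     @st.cache_data
#     def load_qna(path: str = "train_data_mod_obfuscated_fixed.pkl"):
#         with open(path, "rb") as f:
#             return pickle.load(f)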

threading.Thread(target=load_pickle_data, daemon=True).start()

data_load_placeholder = st.empty()
if not data_loaded.is_set():
    data_load_placeholder.info("🔄 Loading medical Q&A data...")

# Wait briefly for the data; warn and continue if it is still loading.
if data_loaded.wait(timeout=5):
    data_load_placeholder.success("✅ Q&A data loaded successfully!")
else:
    data_load_placeholder.warning("⚠️ Q&A data is still loading. Chatbot responses may be delayed.")


def chatbot_response(user_input: str) -> str:
    """Answer from the Q&A data when possible, otherwise generate with FLAN-T5."""
    if not data_loaded.is_set() or not medical_qna:
        return "The Q&A data is still loading. Please try again shortly."

    # First, try a naive substring match against the stored questions.
    for qa in medical_qna:
        if user_input.lower() in qa["question"].lower():
            return qa["answer"]

    # No match: fall back to the model with a lightly structured prompt.
    prompt = (
        "You are a helpful medical assistant. The user asked:\n"
        f"Question: {user_input}\n\n"
        "Answer in a concise, accurate way. If you're unsure, advise seeing a doctor."
    )
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True).to(model.device)
    outputs = model.generate(
        **inputs,
        max_length=256,
        num_beams=2,
        no_repeat_ngram_size=2
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
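
# UI: a quick help button plus a question box wired to chatbot_response.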

if st.button("What can you help me with?"):
    st.write("I can provide general information about medical symptoms and treatments, and offer guidance. If you have serious concerns, please contact a doctor.")

user_input = st.text_input("Ask me a medical question:")
if st.button("Get Answer"):
    if user_input.strip():
        response = chatbot_response(user_input)
        st.write(f"**Bot:** {response}")
    else:
        st.warning("Please enter a question.")