import os
import pickle

import pandas as pd
import streamlit as st
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from sentence_transformers import SentenceTransformer, models

from src.search import init_faiss

# Hugging Face repo IDs
DATASET_REPO = "param2004/Medilingua-dataset"
MODEL_REPO = "param2004/Medilingua-model"

@st.cache_resource
def load_model():
    """Load SapBERT dynamically from Hugging Face Hub"""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    st.info(f"🔬 Loading SapBERT from Hugging Face Hub on {device.upper()}...")

    try:
        # Download the full model folder from the Hub; models.Transformer
        # needs a directory with config.json, tokenizer files, and weights,
        # not a path to a single pytorch_model.bin.
        local_dir = snapshot_download(
            repo_id=MODEL_REPO,
            allow_patterns="models/SapBERT-from-PubMedBERT-fulltext/*",
        )
        model_path = os.path.join(local_dir, "models", "SapBERT-from-PubMedBERT-fulltext")

        # Build the SentenceTransformer manually: transformer backbone + mean pooling
        word_embedding_model = models.Transformer(model_path)
        pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
        model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device)

        st.success("✅ SapBERT loaded successfully from Hugging Face Hub.")
    except Exception as e:
        st.error(f"❌ Failed to load SapBERT from Hub: {e}")
        st.warning("⚠️ Falling back to 'all-MiniLM-L6-v2' model.")
        model = SentenceTransformer('all-MiniLM-L6-v2', device=device)

    return model
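
# Illustrative usage (the names and query string below are examples, not part
# of this module's API): the cached loader is meant to be called from the
# Streamlit page code, and the returned model encodes queries into float32
# vectors matching the FAISS index dtype.
#
#     model = load_model()  # cached across reruns by @st.cache_resource
#     query_vec = model.encode(
#         ["What are the symptoms of anemia?"],
#         convert_to_numpy=True,
#     ).astype("float32")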

@st.cache_resource
def load_data():
    """Load embeddings and dataset dynamically from Hugging Face Hub"""
    try:
        # Download embeddings & CSV from Hub
        question_emb_path = hf_hub_download(DATASET_REPO, filename="dataset/question_embeddings.pkl")
        doctor_emb_path = hf_hub_download(DATASET_REPO, filename="dataset/doctor_embeddings.pkl")
        dataset_csv_path = hf_hub_download(DATASET_REPO, filename="dataset/dataset.csv")

        # Each pickle is expected to hold a dict with an 'embeddings' array;
        # index it directly so a missing key fails loudly instead of raising
        # an opaque AttributeError on None.
        with open(question_emb_path, 'rb') as f:
            question_data = pickle.load(f)
        question_embeddings = question_data['embeddings'].astype('float32')

        with open(doctor_emb_path, 'rb') as f:
            doctor_data = pickle.load(f)
        doctor_embeddings = doctor_data['embeddings'].astype('float32')

        # Load dataset CSV
        df = pd.read_csv(dataset_csv_path)
        df.dropna(subset=['Description', 'Patient', 'Doctor'], inplace=True)
        df.drop_duplicates(inplace=True)

        # Truncate everything to a common length as a safety net; this assumes
        # the embeddings were generated from the CSV after the same cleaning,
        # so rows and embedding vectors stay aligned.
        num_samples = min(len(df), len(question_embeddings), len(doctor_embeddings))
        df = df.iloc[:num_samples]
        question_embeddings = question_embeddings[:num_samples]
        doctor_embeddings = doctor_embeddings[:num_samples]

        st.success(f"✅ Loaded {num_samples} rows with SapBERT embeddings ({question_embeddings.shape[1]} dims)")

        # Initialize FAISS
        init_faiss(question_embeddings)

        return {
            "question_embeddings": question_embeddings,
            "doctor_embeddings": doctor_embeddings,
            "description_column": df["Description"].tolist(),
            "patient_column": df["Patient"].tolist(),
            "original_answers": df["Doctor"].tolist(),
        }

    except Exception as e:
        st.error(f"❌ Error loading dataset or embeddings: {e}")
        return None
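
# For reference: `init_faiss` is defined in src.search and not shown here.
# A minimal sketch of what it plausibly does, assuming a flat (exact) FAISS
# index over the question embeddings; the real implementation may differ:
#
#     import faiss
#
#     _index = None
#
#     def init_faiss(embeddings):
#         global _index
#         _index = faiss.IndexFlatL2(embeddings.shape[1])  # exact L2 search
#         _index.add(embeddings)  # expects a float32 array of shape (n, dim)
#
#     def search(query_vec, k=5):
#         # Returns (distances, indices) of the k nearest stored questions
#         return _index.search(query_vec, k)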