File size: 6,824 Bytes
1a53fb9
 
 
 
dcaf215
 
eb2b7d5
 
 
f3b354e
1a53fb9
dcaf215
1a53fb9
 
 
dcaf215
1a53fb9
 
 
 
 
eb2b7d5
1a53fb9
 
 
 
eb2b7d5
dcaf215
eb2b7d5
 
 
56c08fb
dcaf215
1a53fb9
 
 
 
 
 
 
 
 
dcaf215
1a53fb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcaf215
1a53fb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcaf215
 
 
1a53fb9
 
 
 
dcaf215
1a53fb9
dcaf215
1a53fb9
dcaf215
1a53fb9
 
 
 
 
dcaf215
1a53fb9
8d96f02
1a53fb9
 
dcaf215
8d96f02
1a53fb9
 
 
8d96f02
 
1a53fb9
 
 
 
8d96f02
eb2b7d5
1a53fb9
8d96f02
 
 
 
 
 
 
dcaf215
eb2b7d5
dcaf215
 
 
 
1a53fb9
 
dcaf215
 
 
1a53fb9
 
 
dcaf215
1a53fb9
 
 
 
 
 
 
 
 
 
 
 
dcaf215
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import gradio as gr
import pandas as pd
import re
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from transformers import AutoTokenizer, AutoModelForCausalLM
import os

# --- Hugging Face token for gated models ---
# Indexing os.environ (rather than .get) makes a missing HF_API_TOKEN fail
# right here with a KeyError, before any dataset or model is loaded.
HF_TOKEN = os.environ["HF_API_TOKEN"]

# --- Paths ---
CSV_FOLDER = "data"
FAISS_INDEX_PATH = "data/faiss_index_hugface_BAAI_new"

# --- Load CSVs ---
# The three cleaned datasets are read once at import time and reused by every query.
d1 = pd.read_csv(f"{CSV_FOLDER}/dataset1_clean.csv")
d2 = pd.read_csv(f"{CSV_FOLDER}/dataset2_clean.csv")
d3 = pd.read_csv(f"{CSV_FOLDER}/dataset3_clean.csv")
print("✅ CSVs loaded")

# --- Load FAISS ---
# NOTE(review): presumably the index on disk was built with this same
# embedding model — confirm, otherwise similarity scores are meaningless.
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en")
# SECURITY: allow_dangerous_deserialization=True lets load_local unpickle the
# stored index metadata. Only point FAISS_INDEX_PATH at files you built or trust.
faiss_index = FAISS.load_local(FAISS_INDEX_PATH, embeddings, allow_dangerous_deserialization=True)
print("✅ FAISS loaded")

# --- Load Mistral model (CPU-friendly) ---
# No device or device_map is passed, so the weights stay on the default device.
# `token=` replaces the deprecated `use_auth_token=` keyword of from_pretrained.
MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, token=HF_TOKEN)
print("✅ Mistral model loaded on CPU")

# --- Property synonyms ---
# Maps a lower-cased property name to the substrings that identify its column
# in the CSV headers (see find_column_for_property).
#
# The strength properties are registered under both spellings: the
# underscore keys are kept for backward compatibility, while the
# space-separated keys are the names parse_query_to_filters actually emits
# ("yield strength", "ultimate strength"). Without the latter the lookup
# missed and abbreviations such as "ys" / "uts" were never tried.
property_synonyms = {
    "hardness": ["hv", "hardness", "vicker's hardness", "vickers hardness"],
    "bulk modulus": ["d_bulk (gpa)", "bulk modulus", "bulk_modulus"],
    "yield_strength": ["ys", "yield stress", "yield strength"],
    "yield strength": ["ys", "yield stress", "yield strength"],
    "ultimate_strength": ["uts", "tensile strength", "ultimate tensile strength"],
    "ultimate strength": ["uts", "tensile strength", "ultimate tensile strength"],
    "phase": ["phase_label", "bcc/fcc/other", "phase", "microstructure"],
    "density": ["density_exp", "density_calc", "density"],
}

# --- Helper functions ---
def find_column_for_property(df, property_name):
    """Return the first column of *df* whose name contains a property synonym.

    The synonym list is looked up in ``property_synonyms`` under the
    lower-cased property name; a property with no entry is searched for
    under its own name. Synonyms are tried in list order and, for each
    synonym, columns are scanned in DataFrame order, so an earlier synonym
    takes priority over an earlier column. Matching is a case-insensitive
    substring test.

    Returns:
        The matching column name, or ``None`` when no column matches.
    """
    candidates = property_synonyms.get(property_name.lower(), [property_name])
    hits = (
        column
        for needle in candidates
        for column in df.columns
        if needle.lower() in column.lower()
    )
    return next(hits, None)

def parse_query_to_filters(question):
    """Extract phase and numeric-property filters from a free-text question.

    The question is lower-cased and scanned for:

    * a phase keyword (``fcc``, ``bcc``, ``hcp``, ``other``; the first one in
      that order that occurs wins), stored under ``"phase"``;
    * for each numeric property, an explicit comparison such as
      ``hardness >= 500`` (stored as ``">=500"``) and/or a ranking request
      such as ``highest density`` / ``low hardness`` (stored as ``"high"`` /
      ``"low"``). When both forms occur for the same property the ranking
      request overrides the comparison; ``high`` is checked before ``low``.

    Keywords must not be glued to other letters, so "another" does not
    select the "other" phase and "flow density" is not read as "low density".

    Args:
        question: The user's question.

    Returns:
        A dict mapping ``"phase"`` and/or property names to filter strings;
        empty when nothing is recognised.
    """
    query = question.lower()
    filters = {}

    for phase in ["fcc", "bcc", "hcp", "other"]:
        # Letter look-arounds instead of a bare substring test: "others",
        # "another" or "otherwise" must not trigger the "other" phase.
        if re.search(rf"(?<![a-z]){phase}(?![a-z])", query):
            filters["phase"] = phase
            break

    numeric_props = ["hardness", "bulk modulus", "yield strength", "ultimate strength", "density"]
    for prop in numeric_props:
        # Escape the name so it is always matched literally inside the regex.
        prop_re = re.escape(prop)

        match = re.search(rf"{prop_re}\s*(>=|<=|=|>|<)\s*(\d+\.?\d*)", query)
        if match:
            op, val = match.groups()
            filters[prop] = f"{op}{val}"

        # A ranking request replaces any comparison found above. The
        # look-behind stops words that merely end in "high"/"low"
        # ("below", "flow", "follow") from being read as a ranking keyword.
        if re.search(rf"(?<![a-z])high(?:est)?\s+{prop_re}", query):
            filters[prop] = "high"
        elif re.search(rf"(?<![a-z])low(?:est)?\s+{prop_re}", query):
            filters[prop] = "low"
    return filters

def apply_numeric_filter(df, col, filter_value):
    """Apply one parsed filter to column *col* of *df*.

    Args:
        df: DataFrame to filter or sort.
        col: Name of the column the filter applies to.
        filter_value: ``"high"`` (sort descending), ``"low"`` (sort
            ascending), or a comparison string made of an operator
            (``>=``, ``<=``, ``=``, ``>``, ``<``) followed by a number.

    Returns:
        A sorted or row-filtered DataFrame. A comparison string that cannot
        be parsed leaves *df* unchanged (the same object is returned).
    """
    # Ranking requests only reorder rows; they never drop any.
    sort_ascending = {"high": False, "low": True}
    if filter_value in sort_ascending:
        return df.sort_values(by=col, ascending=sort_ascending[filter_value])

    parsed = re.match(r"(>=|<=|=|>|<)(\d+\.?\d*)", filter_value)
    if parsed is None:
        return df

    symbol = parsed.group(1)
    threshold = float(parsed.group(2))
    # Each operator maps to the equivalent element-wise comparison method.
    comparators = {">": "gt", "<": "lt", ">=": "ge", "<=": "le", "=": "eq"}
    keep = getattr(df[col], comparators[symbol])(threshold)
    return df[keep]

def filter_all_datasets(datasets, queries, top_n=10):
    """Run the parsed filters against every dataset and keep the top rows.

    Args:
        datasets: Iterable of ``(DataFrame, name)`` pairs.
        queries: Mapping of filters; an optional ``"phase"`` entry plus
            property-name -> filter-string entries understood by
            ``apply_numeric_filter``.
        top_n: Maximum number of rows kept per dataset.

    Returns:
        Dict mapping dataset name to a DataFrame holding the ``formula``
        column (when present), the column matched for each requested
        property, the phase column (when a phase filter was applied) and a
        ``Source`` column set to the dataset name. A dataset is left out
        when a phase is requested but it has no phase column, when a
        requested property has no matching column, or when no rows survive.
    """
    matches = {}
    wanted_phase = queries.get("phase", None)
    # Everything except the phase entry is handled as a numeric property filter.
    property_filters = [(prop, cond) for prop, cond in queries.items() if prop != "phase"]

    for frame, label in datasets:
        subset = frame.copy()
        phase_column = None

        if wanted_phase:
            # First column whose name contains any phase synonym.
            phase_column = next(
                (
                    name
                    for name in subset.columns
                    if any(key in name.lower() for key in property_synonyms["phase"])
                ),
                None,
            )
            if not phase_column:
                continue
            phase_mask = subset[phase_column].str.contains(wanted_phase, case=False, na=False)
            subset = subset[phase_mask]

        property_missing = False
        for prop, condition in property_filters:
            column = find_column_for_property(subset, prop)
            if column is None:
                # This dataset does not record the property at all.
                property_missing = True
                break
            # Rows without a value cannot be compared or ranked.
            subset = subset[subset[column].notna()]
            subset = apply_numeric_filter(subset, column, condition)

        if property_missing or subset.empty:
            continue

        display = ["formula"] if "formula" in subset.columns else []
        for prop, _ in property_filters:
            column = find_column_for_property(subset, prop)
            if column and column in subset.columns:
                display.append(column)
        if wanted_phase and phase_column:
            display.append(phase_column)
        display = [name for name in display if name in subset.columns]

        top_rows = subset[display].head(top_n).copy()
        top_rows["Source"] = label
        matches[label] = top_rows

    return matches

# --- Main HEA query function ---
def query_hea(question, top_k=5):
    """Answer an HEA question from FAISS passages, CSV matches and the LLM.

    Args:
        question: The user's free-text question.
        top_k: Number of FAISS passages to retrieve and the maximum number
            of rows kept per CSV dataset.

    Returns:
        A tuple ``(answer, merged_df, faiss_text)``: the text generated by
        the model (without the prompt), the CSV matches of all datasets
        concatenated into one DataFrame (empty when nothing matched), and
        the retrieved passages joined by newlines.
    """
    # FAISS retrieval
    faiss_results = faiss_index.similarity_search(question, k=top_k)
    faiss_text = "\n".join(doc.page_content for doc in faiss_results)

    # CSV filtering
    queries = parse_query_to_filters(question)
    csv_results_dict = filter_all_datasets(
        [(d1, "MPEA"), (d2, "MLPred"), (d3, "Achief")],
        queries,
        top_n=top_k
    )
    csv_context = "".join(
        f"\n### {name} matches:\n{df_filtered.to_string(index=False)}\n"
        for name, df_filtered in csv_results_dict.items()
    )

    # Prompt for Mistral
    prompt = f"""
You are a materials scientist. Based on the following context, answer precisely.
FAISS context: {faiss_text}
CSV datasets context: {csv_context}
Question: {question}
Answer:
"""

    # Tokenize and generate
    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
    # Greedy decoding: sampling with temperature=0.0 is rejected by
    # transformers (temperature must be strictly positive), and
    # do_sample=False is the supported way to get deterministic output.
    outputs = model.generate(**inputs, max_new_tokens=512, do_sample=False)
    # The returned sequence starts with the prompt tokens; decode only what
    # was generated after them so the answer does not echo the whole context.
    prompt_length = inputs["input_ids"].shape[1]
    answer = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

    # Merge CSV results
    merged_df = pd.concat(csv_results_dict.values(), ignore_index=True) if csv_results_dict else pd.DataFrame()

    return answer, merged_df, faiss_text

# --- Gradio wrapper ---
def gradio_query(question):
    """Gradio callback: run ``query_hea`` with its default ``top_k``.

    Returns the answer text, the CSV matches DataFrame and the FAISS
    context text, in that order.
    """
    answer, csv_matches, paper_context = query_hea(question)
    return answer, csv_matches, paper_context

# --- Launch Gradio interface ---
# One question box in; the three output components are listed in the same
# order as the values returned by gradio_query (answer, CSV matches, context).
_question_box = gr.Textbox(lines=2, placeholder="Ask about HEAs...")
_answer_box = gr.Textbox(label="LLM Answer")
_matches_table = gr.Dataframe(label="CSV Matches")
_context_box = gr.Textbox(label="Paper Context (FAISS)")

demo = gr.Interface(
    fn=gradio_query,
    inputs=_question_box,
    outputs=[_answer_box, _matches_table, _context_box],
    title="🔬 HEA Query",
    description="Query HEA datasets + FAISS paper embeddings",
)

# Starts the web UI as soon as the module is executed.
demo.launch()