Spaces:
Runtime error
Runtime error
File size: 6,824 Bytes
1a53fb9 dcaf215 eb2b7d5 f3b354e 1a53fb9 dcaf215 1a53fb9 dcaf215 1a53fb9 eb2b7d5 1a53fb9 eb2b7d5 dcaf215 eb2b7d5 56c08fb dcaf215 1a53fb9 dcaf215 1a53fb9 dcaf215 1a53fb9 dcaf215 1a53fb9 dcaf215 1a53fb9 dcaf215 1a53fb9 dcaf215 1a53fb9 dcaf215 1a53fb9 8d96f02 1a53fb9 dcaf215 8d96f02 1a53fb9 8d96f02 1a53fb9 8d96f02 eb2b7d5 1a53fb9 8d96f02 dcaf215 eb2b7d5 dcaf215 1a53fb9 dcaf215 1a53fb9 dcaf215 1a53fb9 dcaf215 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 | import gradio as gr
import pandas as pd
import re
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
# --- Hugging Face token for gated models ---
# Read the token from HF_API_TOKEN, falling back to the standard HF_TOKEN
# secret. A missing variable no longer aborts the import with a bare KeyError;
# with token=None the Hugging Face libraries use their default credentials.
HF_TOKEN = os.environ.get("HF_API_TOKEN") or os.environ.get("HF_TOKEN")
# --- Paths ---
CSV_FOLDER = "data"
FAISS_INDEX_PATH = "data/faiss_index_hugface_BAAI_new"
# --- Load CSVs ---
d1 = pd.read_csv(f"{CSV_FOLDER}/dataset1_clean.csv")
d2 = pd.read_csv(f"{CSV_FOLDER}/dataset2_clean.csv")
d3 = pd.read_csv(f"{CSV_FOLDER}/dataset3_clean.csv")
print("✅ CSVs loaded")
# --- Load FAISS ---
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en")
# SECURITY NOTE: allow_dangerous_deserialization unpickles the stored index
# metadata. Only acceptable because the index is read from a local path under
# data/ that ships with this app; never point this at untrusted files.
faiss_index = FAISS.load_local(
    FAISS_INDEX_PATH, embeddings, allow_dangerous_deserialization=True
)
print("✅ FAISS loaded")
# --- Load Mistral model (CPU-friendly) ---
MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"
# `token` replaces the deprecated `use_auth_token` argument.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
# torch_dtype="auto" keeps the checkpoint's stored dtype (presumably bf16 for
# this model) instead of upcasting to float32, roughly halving the RAM needed
# for the 7B weights on a CPU host.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, token=HF_TOKEN, torch_dtype="auto"
)
print("✅ Mistral model loaded on CPU")
# --- Property synonyms ---
# Canonical property name -> lowercase fragments used to recognise the
# matching column header in the CSV datasets (compared case-insensitively,
# as substrings of the column name).
property_synonyms = {
    "hardness": ["hv", "hardness", "vicker's hardness", "vickers hardness"],
    "bulk modulus": ["d_bulk (gpa)", "bulk modulus", "bulk_modulus"],
    # NOTE: these two keys use underscores, while parse_query_to_filters emits
    # "yield strength" / "ultimate strength" with a space; lookups must
    # normalise the spelling to reach these synonym lists.
    "yield_strength": ["ys", "yield stress", "yield strength"],
    "ultimate_strength": ["uts", "tensile strength", "ultimate tensile strength"],
    "phase": ["phase_label", "bcc/fcc/other", "phase", "microstructure"],
    "density": ["density_exp", "density_calc", "density"],
}
# --- Helper functions ---
def find_column_for_property(df, property_name):
    """Return the column of *df* that holds *property_name*, or None.

    The property name is looked up (case-insensitively) in the module-level
    ``property_synonyms`` mapping. Space- and underscore-separated spellings
    are treated as the same key, so "yield strength" (as produced by
    parse_query_to_filters) reaches the "yield_strength" synonym list. If the
    name is unknown, the name itself is used as the only synonym.

    Column headers are matched case-insensitively. An exact header match for
    any synonym wins first; only then is a substring match accepted, in
    synonym order.
    """
    key = property_name.lower()
    synonyms = (
        property_synonyms.get(key)
        or property_synonyms.get(key.replace(" ", "_"))
        or property_synonyms.get(key.replace("_", " "))
        or [property_name]
    )
    lowered = [(col, str(col).lower()) for col in df.columns]
    # Exact headers first, so short synonyms such as "ys" or "hv" do not latch
    # onto an unrelated column that merely contains those letters.
    for syn in synonyms:
        syn_l = syn.lower()
        for col, col_l in lowered:
            if col_l == syn_l:
                return col
    for syn in synonyms:
        syn_l = syn.lower()
        for col, col_l in lowered:
            if syn_l in col_l:
                return col
    return None
def parse_query_to_filters(question):
    """Extract simple structured filters from a free-text question.

    Recognised filters (matching is case-insensitive):
      * ``"phase"``: the first of "fcc", "bcc", "hcp", "other" that appears
        in the question as a whole word.
      * numeric properties (hardness, bulk modulus, yield strength, ultimate
        strength, density):
        - an explicit comparison such as "hardness > 500" becomes ">500";
        - the phrases "high(est) <prop>" / "low(est) <prop>" become
          "high" / "low", and take precedence over a comparison.

    Returns:
        dict mapping filter name to filter value (empty if nothing matched).
    """
    query = question.lower()
    filters = {}
    # Whole-word matching: a plain substring test let "another" select the
    # "other" phase.
    for phase in ("fcc", "bcc", "hcp", "other"):
        if re.search(rf"\b{phase}\b", query):
            filters["phase"] = phase
            break
    numeric_props = ["hardness", "bulk modulus", "yield strength", "ultimate strength", "density"]
    for prop in numeric_props:
        name = re.escape(prop)
        match = re.search(rf"{name}\s*(>=|<=|=|>|<)\s*(\d+\.?\d*)", query)
        if match:
            op, val = match.groups()
            filters[prop] = f"{op}{val}"
        # \b before high/low keeps e.g. "below density" from reading as
        # "low density".
        if re.search(rf"\b(?:highest|high)\s+{name}", query):
            filters[prop] = "high"
        elif re.search(rf"\b(?:lowest|low)\s+{name}", query):
            filters[prop] = "low"
    return filters
def apply_numeric_filter(df, col, filter_value):
    """Filter or sort *df* on column *col* according to *filter_value*.

    Args:
        df: DataFrame to filter.
        col: name of the column to compare on.
        filter_value: "high" (sort descending), "low" (sort ascending), or
            a comparison string such as ">500", "<=7.5" or "=3".

    Column values are compared numerically via
    ``pd.to_numeric(errors="coerce")``. Numbers stored as text therefore
    work; non-numeric cells never match a comparison and sort last.

    Returns:
        The filtered/sorted DataFrame, or *df* unchanged when *filter_value*
        is not a recognised form.
    """
    def as_number(series):
        return pd.to_numeric(series, errors="coerce")

    if filter_value in ("high", "low"):
        return df.sort_values(by=col, ascending=(filter_value == "low"), key=as_number)
    match = re.match(r"(>=|<=|=|>|<)(\d+\.?\d*)", filter_value)
    if not match:
        return df
    op, val_str = match.groups()
    val = float(val_str)
    values = as_number(df[col])
    comparisons = {
        ">": values.gt,
        "<": values.lt,
        ">=": values.ge,
        "<=": values.le,
        "=": values.eq,
    }
    return df[comparisons[op](val)]
def _find_phase_column(df):
    """Return the first column whose name contains a phase synonym, or None."""
    phase_keys = property_synonyms["phase"]
    for col in df.columns:
        col_l = str(col).lower()
        if any(key in col_l for key in phase_keys):
            return col
    return None


def _apply_property_filters(df, queries):
    """Apply every non-phase filter in *queries* to *df*.

    Returns (filtered_df, matched_columns), or (None, []) when *df* has no
    column for one of the requested properties.
    """
    matched_cols = []
    for prop, filter_val in queries.items():
        if prop == "phase":
            continue
        col = find_column_for_property(df, prop)
        if col is None:
            return None, []
        df = df[df[col].notna()]
        df = apply_numeric_filter(df, col, filter_val)
        matched_cols.append(col)
    return df, matched_cols


def filter_all_datasets(datasets, queries, top_n=10):
    """Filter each dataset with *queries* and return its top matching rows.

    Args:
        datasets: iterable of (DataFrame, name) pairs.
        queries: filter dict as produced by parse_query_to_filters, e.g.
            {"phase": "fcc", "hardness": ">500"}.
        top_n: maximum number of rows kept per dataset.

    Returns:
        dict mapping dataset name to a DataFrame containing:
          * the "formula" column (if present);
          * the matched property columns;
          * the phase column (if a phase filter was used);
          * a "Source" column set to the dataset name.
        Omitted from the result: datasets lacking a phase column (when a
        phase filter is given), datasets lacking a requested property
        column, and datasets left with no rows.
    """
    results = {}
    phase_filter = queries.get("phase")
    for df, name in datasets:
        df_filtered = df
        phase_col = None
        if phase_filter:
            phase_col = _find_phase_column(df_filtered)
            if phase_col is None:
                continue
            # astype("string") tolerates non-text phase columns; regex=False
            # matches the phase label literally.
            mask = df_filtered[phase_col].astype("string").str.contains(
                phase_filter, case=False, na=False, regex=False
            )
            df_filtered = df_filtered[mask]
        df_filtered, prop_cols = _apply_property_filters(df_filtered, queries)
        if df_filtered is None or df_filtered.empty:
            continue
        show_cols = []
        if "formula" in df_filtered.columns:
            show_cols.append("formula")
        show_cols.extend(prop_cols)
        if phase_col is not None:
            show_cols.append(phase_col)
        # De-duplicate while keeping order: repeated labels would create
        # duplicate columns, which can make a later pd.concat of results fail.
        show_cols = list(dict.fromkeys(show_cols))
        top = df_filtered[show_cols].head(top_n).copy()
        top["Source"] = name
        results[name] = top
    return results
# --- Main HEA query function ---
def query_hea(question, top_k=5, max_new_tokens=512):
    """Answer *question* using FAISS paper context, CSV matches and Mistral.

    Args:
        question: free-text user question.
        top_k: number of FAISS documents to retrieve; also the per-dataset
            row limit for the CSV matches.
        max_new_tokens: generation budget for the LLM answer.

    Returns:
        (answer, merged_df, faiss_text):
          * answer: only the newly generated text (the prompt is not echoed);
          * merged_df: the concatenated CSV matches with their "Source"
            column, or an empty DataFrame if nothing matched;
          * faiss_text: the retrieved paper snippets joined by newlines.
    """
    # FAISS retrieval
    faiss_results = faiss_index.similarity_search(question, k=top_k)
    faiss_text = "\n".join(doc.page_content for doc in faiss_results)

    # CSV filtering
    queries = parse_query_to_filters(question)
    csv_results_dict = filter_all_datasets(
        [(d1, "MPEA"), (d2, "MLPred"), (d3, "Achief")],
        queries,
        top_n=top_k,
    )
    csv_context = "".join(
        f"\n### {name} matches:\n{df_filtered.to_string(index=False)}\n"
        for name, df_filtered in csv_results_dict.items()
    )

    # Prompt for Mistral.
    # NOTE(review): an instruct model may answer better with
    # tokenizer.apply_chat_template; kept as a plain prompt here.
    prompt = (
        "\n"
        "You are a materials scientist. Based on the following context, answer precisely.\n"
        f"FAISS context: {faiss_text}\n"
        f"CSV datasets context: {csv_context}\n"
        f"Question: {question}\n"
        "Answer:\n"
    )

    # Tokenize and generate. Greedy decoding gives deterministic answers;
    # sampling with temperature=0.0 is rejected by transformers.
    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; decode only the new tokens.
    prompt_len = inputs["input_ids"].shape[-1]
    answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

    # Merge CSV results
    merged_df = (
        pd.concat(list(csv_results_dict.values()), ignore_index=True)
        if csv_results_dict
        else pd.DataFrame()
    )
    return answer, merged_df, faiss_text
# --- Gradio wrapper ---
def gradio_query(question):
    """Gradio callback: run query_hea on the submitted question.

    Returns the (answer, CSV matches, FAISS context) triple in the same
    order as the Interface's output components.
    """
    answer, csv_matches, paper_context = query_hea(question)
    return answer, csv_matches, paper_context
# --- Launch Gradio interface ---
# One text input; three outputs matching the (answer, CSV matches, FAISS
# context) triple returned by gradio_query.
demo = gr.Interface(
    fn=gradio_query,
    inputs=gr.Textbox(lines=2, placeholder="Ask about HEAs..."),
    outputs=[
        gr.Textbox(label="LLM Answer"),
        gr.Dataframe(label="CSV Matches"),
        gr.Textbox(label="Paper Context (FAISS)"),
    ],
    title="🔬 HEA Query",
    description="Query HEA datasets + FAISS paper embeddings",
)
# Start the server only when run as a script, so importing this module
# (e.g. to reuse `demo`) does not block on launch().
if __name__ == "__main__":
    demo.launch()
|