# HEA_Query / app.py
# taradutt007's picture
# Update app.py
# f3b354e verified
import gradio as gr
import pandas as pd
import re
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
# --- Hugging Face token for gated models ---
# Read with [] rather than .get(): a missing HF_API_TOKEN raises KeyError
# right here, at import time.
HF_TOKEN = os.environ["HF_API_TOKEN"]

# --- Paths ---
CSV_FOLDER = "data"
FAISS_INDEX_PATH = "data/faiss_index_hugface_BAAI_new"

# --- Load CSVs ---
d1 = pd.read_csv(f"{CSV_FOLDER}/dataset1_clean.csv")
d2 = pd.read_csv(f"{CSV_FOLDER}/dataset2_clean.csv")
d3 = pd.read_csv(f"{CSV_FOLDER}/dataset3_clean.csv")
print("✅ CSVs loaded")

# --- Load FAISS ---
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en")
# NOTE(review): allow_dangerous_deserialization=True lets load_local unpickle
# the stored index metadata. That is only acceptable while the files under
# FAISS_INDEX_PATH are a trusted artifact — never point this at an index
# obtained from an untrusted source.
faiss_index = FAISS.load_local(
    FAISS_INDEX_PATH,
    embeddings,
    allow_dangerous_deserialization=True,
)
print("✅ FAISS loaded")

# --- Load Mistral model (CPU-friendly) ---
MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"
# `token=` is the supported keyword for authenticating against gated repos;
# the older `use_auth_token=` spelling is deprecated in Transformers.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, token=HF_TOKEN)
print("✅ Mistral model loaded on CPU")
# --- Property synonyms ---
# Maps a canonical property name to the lowercase fragments that may appear
# in a dataset column header for that property. Lists are ordered: earlier
# fragments are tried first by the column lookup.
property_synonyms = {
    "hardness": ["hv", "hardness", "vicker's hardness", "vickers hardness"],
    "bulk modulus": ["d_bulk (gpa)", "bulk modulus", "bulk_modulus"],
    "yield_strength": ["ys", "yield stress", "yield strength"],
    "ultimate_strength": ["uts", "tensile strength", "ultimate tensile strength"],
    "phase": ["phase_label", "bcc/fcc/other", "phase", "microstructure"],
    "density": ["density_exp", "density_calc", "density"],
}

# The query parser emits property names with spaces ("yield strength",
# "ultimate strength"), so register a space-separated alias for every
# underscore key; otherwise those lookups miss the table and the synonyms
# above are never tried. The original underscore keys remain valid.
property_synonyms.update(
    {
        key.replace("_", " "): list(fragments)
        for key, fragments in property_synonyms.items()
        if "_" in key
    }
)
# --- Helper functions ---
def find_column_for_property(df, property_name):
    """Return the first column of ``df`` whose name mentions ``property_name``.

    The lowercased property name is looked up in ``property_synonyms``; when
    it is not a key there, the name itself is the only candidate. Candidates
    are tried in order, and for each one the columns are scanned left to
    right with a case-insensitive substring test.

    Returns the matching column label, or ``None`` when nothing matches.
    """
    candidates = property_synonyms.get(property_name.lower(), [property_name])
    for candidate in candidates:
        needle = candidate.lower()
        hit = next(
            (column for column in df.columns if needle in column.lower()),
            None,
        )
        if hit is not None:
            return hit
    return None
def parse_query_to_filters(question):
    """Extract phase and numeric-property filters from a free-text question.

    Matching is case-insensitive. The result maps:

    * ``"phase"`` to the first of ``fcc``, ``bcc``, ``hcp``, ``other`` (in
      that priority order) that occurs in the question as a standalone token;
    * a property name (``"hardness"``, ``"bulk modulus"``, ``"yield
      strength"``, ``"ultimate strength"``, ``"density"``) to either a
      comparison such as ``">500"`` / ``"<=7.5"``, or to ``"high"`` / ``"low"``
      when the question says "high(est) <property>" / "low(est) <property>".
      A high/low phrase overrides a comparison on the same property.

    Returns an empty dict when nothing is recognised.
    """
    query = question.lower()
    filters = {}

    # Require non-letter neighbours around the phase name: a bare substring
    # test would report phase "other" for words like "another" or "others".
    for phase in ("fcc", "bcc", "hcp", "other"):
        if re.search(rf"(?<![a-z]){phase}(?![a-z])", query):
            filters["phase"] = phase
            break

    numeric_props = ("hardness", "bulk modulus", "yield strength", "ultimate strength", "density")
    for prop in numeric_props:
        comparison = re.search(
            rf"{re.escape(prop)}\s*(>=|<=|=|>|<)\s*(\d+\.?\d*)", query
        )
        if comparison:
            op, val = comparison.groups()
            filters[prop] = f"{op}{val}"

        # Deliberately plain substring tests so that compound adjectives
        # such as "ultrahigh hardness" still count as "high".
        if f"highest {prop}" in query or f"high {prop}" in query:
            filters[prop] = "high"
        elif f"lowest {prop}" in query or f"low {prop}" in query:
            filters[prop] = "low"

    return filters
def apply_numeric_filter(df, col, filter_value):
    """Apply one parsed filter to column ``col`` of ``df``.

    ``filter_value`` is either ``"high"`` (sort descending), ``"low"`` (sort
    ascending), or a comparison string made of one of ``>=``, ``<=``, ``=``,
    ``>``, ``<`` followed by a non-negative number, e.g. ``">=7.5"``.

    Returns the sorted frame for ``"high"`` / ``"low"``, the rows satisfying
    the comparison otherwise, and ``df`` itself when ``filter_value`` does not
    start with a recognisable comparison.
    """
    # Ranking keywords only reorder rows; nothing is dropped.
    if filter_value in ("high", "low"):
        return df.sort_values(by=col, ascending=(filter_value == "low"))

    parsed = re.match(r"(>=|<=|=|>|<)(\d+\.?\d*)", filter_value)
    if parsed is None:
        return df

    symbol, number_text = parsed.groups()
    threshold = float(number_text)
    comparisons = {
        ">": lambda series: series > threshold,
        "<": lambda series: series < threshold,
        ">=": lambda series: series >= threshold,
        "<=": lambda series: series <= threshold,
        "=": lambda series: series == threshold,
    }
    return df[comparisons[symbol](df[col])]
def filter_all_datasets(datasets, queries, top_n=10):
    """Run the parsed filters against every dataset and collect the matches.

    ``datasets`` is an iterable of ``(DataFrame, name)`` pairs and ``queries``
    the filter mapping (an optional ``"phase"`` entry plus per-property
    filters). A dataset is skipped when a phase filter is requested but no
    phase-like column exists, when any requested property has no matching
    column, or when no rows survive the filters.

    Returns a dict mapping dataset name to a frame holding at most ``top_n``
    rows, restricted to the ``formula`` column (if present), the matched
    property columns and the phase column, plus a ``Source`` column set to
    the dataset name.
    """
    phase_filter = queries.get("phase", None)
    property_queries = [
        (prop, wanted) for prop, wanted in queries.items() if prop != "phase"
    ]
    results = {}

    for df, name in datasets:
        subset = df.copy()
        phase_col = None

        if phase_filter:
            # First column whose header mentions any phase synonym.
            phase_col = next(
                (
                    column
                    for column in subset.columns
                    if any(
                        fragment in column.lower()
                        for fragment in property_synonyms["phase"]
                    )
                ),
                None,
            )
            if not phase_col:
                continue
            phase_mask = subset[phase_col].str.contains(
                phase_filter, case=False, na=False
            )
            subset = subset[phase_mask]

        for prop, wanted in property_queries:
            column = find_column_for_property(subset, prop)
            if column is None:
                # This dataset does not record the property at all.
                subset = None
                break
            subset = subset[subset[column].notna()]
            subset = apply_numeric_filter(subset, column, wanted)

        if subset is None or subset.empty:
            continue

        show_cols = ["formula"] if "formula" in subset.columns else []
        for prop, _ in property_queries:
            column = find_column_for_property(subset, prop)
            if column and column in subset.columns:
                show_cols.append(column)
        if phase_filter and phase_col:
            show_cols.append(phase_col)
        show_cols = [c for c in show_cols if c in subset.columns]

        top_rows = subset[show_cols].head(top_n).copy()
        top_rows["Source"] = name
        results[name] = top_rows

    return results
# --- Main HEA query function ---
def query_hea(question, top_k=5):
    """Answer a question using FAISS passages, CSV matches and the LLM.

    Retrieves the ``top_k`` most similar passages from the FAISS index,
    filters the three CSV datasets with the filters parsed from ``question``
    (keeping up to ``top_k`` rows per dataset), and asks the model to answer
    from both contexts.

    Returns a tuple ``(answer, merged_df, faiss_text)``: the generated answer
    text, the concatenated CSV matches (an empty DataFrame when nothing
    matched), and the retrieved passages joined by newlines.
    """
    # FAISS retrieval
    faiss_results = faiss_index.similarity_search(question, k=top_k)
    faiss_text = "\n".join(doc.page_content for doc in faiss_results)

    # CSV filtering
    queries = parse_query_to_filters(question)
    csv_results_dict = filter_all_datasets(
        [(d1, "MPEA"), (d2, "MLPred"), (d3, "Achief")],
        queries,
        top_n=top_k,
    )
    csv_context = "".join(
        f"\n### {name} matches:\n{df_filtered.to_string(index=False)}\n"
        for name, df_filtered in csv_results_dict.items()
    )

    # Prompt for Mistral
    prompt = f"""
You are a materials scientist. Based on the following context, answer precisely.
FAISS context: {faiss_text}
CSV datasets context: {csv_context}
Question: {question}
Answer:
"""

    # Tokenize and generate
    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
    # Greedy decoding gives the deterministic output that temperature=0.0 was
    # aiming for; sampling with a zero temperature is not a valid
    # configuration and is rejected by the generation code.
    outputs = model.generate(**inputs, max_new_tokens=512, do_sample=False)
    # generate() returns the prompt tokens followed by the continuation, so
    # drop the prompt before decoding to avoid echoing it back as the answer.
    prompt_length = inputs["input_ids"].shape[1]
    answer = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    ).strip()

    # Merge CSV results
    if csv_results_dict:
        merged_df = pd.concat(csv_results_dict.values(), ignore_index=True)
    else:
        merged_df = pd.DataFrame()
    return answer, merged_df, faiss_text
# --- Gradio wrapper ---
def gradio_query(question):
    """Gradio callback: run ``query_hea`` on ``question`` with its defaults.

    Returns whatever ``query_hea`` returns, unchanged.
    """
    response = query_hea(question)
    return response
# --- Launch Gradio interface ---
# One free-text input; three outputs in the order the callback returns them.
question_box = gr.Textbox(lines=2, placeholder="Ask about HEAs...")
result_panels = [
    gr.Textbox(label="LLM Answer"),
    gr.Dataframe(label="CSV Matches"),
    gr.Textbox(label="Paper Context (FAISS)"),
]

demo = gr.Interface(
    title="🔬 HEA Query",
    description="Query HEA datasets + FAISS paper embeddings",
    fn=gradio_query,
    inputs=question_box,
    outputs=result_panels,
)
demo.launch()