"""HEA (high-entropy alloy) query app.

Retrieves paper passages from a FAISS index, filters three CSV datasets with
rules parsed from the question, and asks a Mistral instruct model to answer
using both as context. Served through a Gradio interface.
"""

import operator
import os
import re

import gradio as gr
import pandas as pd
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- Hugging Face token for gated models ---
# Raises KeyError at startup if HF_API_TOKEN is not set.
HF_TOKEN = os.environ["HF_API_TOKEN"]

# --- Paths ---
CSV_FOLDER = "data"
FAISS_INDEX_PATH = "data/faiss_index_hugface_BAAI_new"

# --- Load CSVs ---
d1 = pd.read_csv(f"{CSV_FOLDER}/dataset1_clean.csv")
d2 = pd.read_csv(f"{CSV_FOLDER}/dataset2_clean.csv")
d3 = pd.read_csv(f"{CSV_FOLDER}/dataset3_clean.csv")
print("✅ CSVs loaded")

# --- Load FAISS ---
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en")
# NOTE(review): allow_dangerous_deserialization unpickles the index docstore.
# This is only acceptable if FAISS_INDEX_PATH is a trusted local artifact —
# presumably built by this project; never point it at untrusted files.
faiss_index = FAISS.load_local(
    FAISS_INDEX_PATH, embeddings, allow_dangerous_deserialization=True
)
print("✅ FAISS loaded")

# --- Load Mistral model (CPU-friendly) ---
MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"
# `token` replaces the deprecated `use_auth_token` keyword.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, token=HF_TOKEN)
print("✅ Mistral model loaded on CPU")

# --- Property synonyms ---
# Keys may use spaces or underscores; lookups go through
# _normalize_property_name so "yield strength" and "yield_strength" agree.
property_synonyms = {
    "hardness": ["hv", "hardness", "vicker's hardness", "vickers hardness"],
    "bulk modulus": ["d_bulk (gpa)", "bulk modulus", "bulk_modulus"],
    "yield_strength": ["ys", "yield stress", "yield strength"],
    "ultimate_strength": ["uts", "tensile strength", "ultimate tensile strength"],
    "phase": ["phase_label", "bcc/fcc/other", "phase", "microstructure"],
    "density": ["density_exp", "density_calc", "density"],
}

# Comparison operators accepted in numeric filter strings such as ">=500".
_COMPARISON_OPS = {
    ">": operator.gt,
    "<": operator.lt,
    ">=": operator.ge,
    "<=": operator.le,
    "=": operator.eq,
}


# --- Helper functions ---
def _normalize_property_name(name):
    """Lower-case *name* and treat underscores as spaces."""
    return name.lower().replace("_", " ").strip()


def find_column_for_property(df, property_name):
    """Return the column of *df* holding *property_name*, or None.

    Candidate names come from ``property_synonyms`` (matched on the
    normalized property name), falling back to *property_name* itself.
    A column whose name equals a synonym (case-insensitive) is preferred;
    only if none exists is a column containing a synonym as a substring
    returned, so short synonyms like "ys" or "hv" do not shadow an exact
    column name.
    """
    wanted = _normalize_property_name(property_name)
    synonyms = next(
        (
            syns
            for key, syns in property_synonyms.items()
            if _normalize_property_name(key) == wanted
        ),
        [property_name],
    )
    lowered_cols = [(col, str(col).lower().strip()) for col in df.columns]

    for syn in synonyms:
        syn_lower = syn.lower()
        for col, col_lower in lowered_cols:
            if col_lower == syn_lower:
                return col
    for syn in synonyms:
        syn_lower = syn.lower()
        for col, col_lower in lowered_cols:
            if syn_lower in col_lower:
                return col
    return None


def parse_query_to_filters(question):
    """Extract phase and numeric-property filters from a free-text question.

    Returns a dict that may contain:
      * "phase": the first of fcc/bcc/hcp/other found as a whole word;
      * "<property>": a comparison string like ">=500" (from text such as
        "hardness >= 500"), or "high"/"low" (from "highest"/"high" or
        "lowest"/"low" followed by the property name). A high/low phrase
        overrides a comparison for the same property.
    """
    query = question.lower()
    filters = {}

    for phase in ["fcc", "bcc", "hcp", "other"]:
        # Word boundaries stop e.g. "other" matching inside "another".
        if re.search(rf"\b{phase}\b", query):
            filters["phase"] = phase
            break

    numeric_props = [
        "hardness",
        "bulk modulus",
        "yield strength",
        "ultimate strength",
        "density",
    ]
    for prop in numeric_props:
        pattern = rf"{re.escape(prop)}\s*(>=|<=|=|>|<)\s*(\d+\.?\d*)"
        match = re.search(pattern, query)
        if match:
            op, val = match.groups()
            filters[prop] = f"{op}{val}"
        if f"highest {prop}" in query or f"high {prop}" in query:
            filters[prop] = "high"
        elif f"lowest {prop}" in query or f"low {prop}" in query:
            filters[prop] = "low"
    return filters


def apply_numeric_filter(df, col, filter_value):
    """Rank or filter *df* by the numeric column *col*.

    *filter_value* is "high" (sort descending), "low" (sort ascending), or
    "<op><number>" with op in >, <, >=, <=, =. Cell values are coerced with
    ``pd.to_numeric(errors="coerce")`` so numbers stored as text compare and
    sort numerically; unparseable cells sort last and fail every comparison.
    Any other *filter_value* returns *df* unchanged.
    """

    def as_numeric(series):
        return pd.to_numeric(series, errors="coerce")

    if filter_value == "high":
        return df.sort_values(by=col, ascending=False, key=as_numeric)
    if filter_value == "low":
        return df.sort_values(by=col, ascending=True, key=as_numeric)

    match = re.match(r"(>=|<=|=|>|<)(\d+\.?\d*)", filter_value)
    if not match:
        return df
    op, val_str = match.groups()
    compare = _COMPARISON_OPS[op]
    return df[compare(as_numeric(df[col]), float(val_str))]


def _apply_property_filters(df, numeric_filters):
    """Apply numeric filters in order; return (rows, columns used), or None if a property has no column."""
    matched_cols = []
    for prop, filter_val in numeric_filters.items():
        col = find_column_for_property(df, prop)
        if col is None:
            return None
        df = df[df[col].notna()]
        df = apply_numeric_filter(df, col, filter_val)
        matched_cols.append(col)
    return df, matched_cols


def filter_all_datasets(datasets, queries, top_n=10):
    """Apply *queries* to each dataset and collect the matching rows.

    Args:
        datasets: iterable of (DataFrame, name) pairs.
        queries: filter dict as produced by parse_query_to_filters.
        top_n: maximum number of rows kept per dataset.

    Returns:
        dict mapping dataset name to a DataFrame containing "formula" (if
        present), one column per queried numeric property, the phase column
        (when a phase filter is given), and a "Source" column set to the
        name. A dataset is omitted when it lacks a required column or no
        rows remain after filtering.
    """
    results = {}
    phase_filter = queries.get("phase", None)
    numeric_filters = {
        prop: val for prop, val in queries.items() if prop != "phase"
    }

    for df, name in datasets:
        df_filtered = df
        phase_col = None

        if phase_filter:
            phase_col = find_column_for_property(df_filtered, "phase")
            if phase_col is None:
                continue
            # astype("string") keeps the .str accessor valid whatever dtype
            # the column was read with; missing values stay NA -> excluded.
            mask = (
                df_filtered[phase_col]
                .astype("string")
                .str.contains(phase_filter, case=False, na=False)
            )
            df_filtered = df_filtered[mask]

        filtered = _apply_property_filters(df_filtered, numeric_filters)
        if filtered is None:
            continue
        df_filtered, property_cols = filtered
        if df_filtered.empty:
            continue

        show_cols = []
        if "formula" in df_filtered.columns:
            show_cols.append("formula")
        show_cols.extend(property_cols)
        if phase_col is not None:
            show_cols.append(phase_col)
        # De-duplicate (keeping order) so a column matched twice is shown once.
        show_cols = list(dict.fromkeys(show_cols))

        matched = df_filtered[show_cols].head(top_n).copy()
        matched["Source"] = name
        results[name] = matched
    return results


# --- Main HEA query function ---
def query_hea(question, top_k=5):
    """Answer *question* from FAISS paper context plus filtered CSV rows.

    Args:
        question: free-text user question.
        top_k: number of FAISS passages retrieved, and the maximum number of
            CSV rows kept per dataset.

    Returns:
        (answer, merged_df, faiss_text): the generated answer text (prompt
        excluded), all CSV matches concatenated (empty DataFrame if none),
        and the retrieved paper passages joined by newlines.
    """
    # FAISS retrieval
    faiss_results = faiss_index.similarity_search(question, k=top_k)
    faiss_text = "\n".join(doc.page_content for doc in faiss_results)

    # CSV filtering
    queries = parse_query_to_filters(question)
    csv_results_dict = filter_all_datasets(
        [(d1, "MPEA"), (d2, "MLPred"), (d3, "Achief")], queries, top_n=top_k
    )
    csv_context = "".join(
        f"\n### {name} matches:\n{df_filtered.to_string(index=False)}\n"
        for name, df_filtered in csv_results_dict.items()
    )

    # Prompt for Mistral
    prompt = f"""
You are a materials scientist. Based on the following context, answer precisely.

FAISS context:
{faiss_text}

CSV datasets context:
{csv_context}

Question: {question}
Answer:
"""

    # Tokenize and generate
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Greedy decoding. do_sample=True with temperature=0.0 is rejected by
    # transformers (temperature must be > 0 when sampling); do_sample=False
    # gives the deterministic output temperature 0 was meant to produce.
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation for decoder-only models;
    # decode only the new tokens so the prompt is not echoed to the user.
    prompt_len = inputs["input_ids"].shape[1]
    answer = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()

    # Merge CSV results
    merged_df = (
        pd.concat(csv_results_dict.values(), ignore_index=True)
        if csv_results_dict
        else pd.DataFrame()
    )
    return answer, merged_df, faiss_text


# --- Gradio wrapper ---
def gradio_query(question):
    """Gradio callback: run query_hea, short-circuiting on a blank question."""
    if not question or not question.strip():
        return "Please enter a question.", pd.DataFrame(), ""
    return query_hea(question)


# --- Launch Gradio interface ---
demo = gr.Interface(
    fn=gradio_query,
    inputs=gr.Textbox(lines=2, placeholder="Ask about HEAs..."),
    outputs=[
        gr.Textbox(label="LLM Answer"),
        gr.Dataframe(label="CSV Matches"),
        gr.Textbox(label="Paper Context (FAISS)"),
    ],
    title="🔬 HEA Query",
    description="Query HEA datasets + FAISS paper embeddings",
)
demo.launch()