# Captured from a Hugging Face Space page; Space status at capture time: "Runtime error".
import operator
import os
import re

import gradio as gr
import pandas as pd
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from transformers import AutoTokenizer, AutoModelForCausalLM
# --- Hugging Face token for gated models ---
# Read HF_API_TOKEN, falling back to the standard HF_TOKEN variable. A missing
# token yields None instead of a bare KeyError at import time; huggingface_hub
# then falls back to any locally saved login, and a gated download that still
# lacks credentials fails later with an explicit authorization error.
HF_TOKEN = os.environ.get("HF_API_TOKEN") or os.environ.get("HF_TOKEN")
# --- Paths ---
# Directory holding the cleaned CSV datasets; the FAISS index lives inside it too.
CSV_FOLDER = "data"
FAISS_INDEX_PATH = f"{CSV_FOLDER}/faiss_index_hugface_BAAI_new"
# --- Load CSVs ---
# dataset1/2/3 are the tables that query_hea labels "MPEA", "MLPred" and "Achief".
d1, d2, d3 = (
    pd.read_csv(f"{CSV_FOLDER}/dataset{idx}_clean.csv") for idx in (1, 2, 3)
)
print("✅ CSVs loaded")
# --- Load FAISS ---
# Queries are embedded with this model at search time; presumably it must be the
# same model that built the index — confirm against the indexing script.
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en")
# NOTE(security): load_local unpickles the stored docstore, hence the explicit
# opt-in flag. Only acceptable because the index ships with this app; never
# point FAISS_INDEX_PATH at files from an untrusted source.
faiss_index = FAISS.load_local(
    folder_path=FAISS_INDEX_PATH,
    embeddings=embeddings,
    allow_dangerous_deserialization=True,
)
print("✅ FAISS loaded")
# --- Load Mistral model (CPU-friendly) ---
MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"
# `token=` is the current name of the deprecated `use_auth_token=` argument.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
# Without torch_dtype, from_pretrained materialises the weights in float32
# (4 bytes/param, ~29 GB for ~7.2B params). "auto" keeps the dtype recorded in
# the checkpoint config instead (bfloat16 for this model, presumably — confirm),
# roughly halving host RAM, which matters on a CPU-only Space.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    torch_dtype="auto",
)
print("✅ Mistral model loaded on CPU")
# --- Property synonyms ---
# Canonical property name -> lowercase substrings that may appear in a dataset's
# column headers for that property. Column lookup tries the substrings in list
# order, so earlier entries take priority.
property_synonyms = {
    # Vickers hardness
    "hardness": ["hv", "hardness", "vicker's hardness", "vickers hardness"],
    "bulk modulus": ["d_bulk (gpa)", "bulk modulus", "bulk_modulus"],
    "yield_strength": ["ys", "yield stress", "yield strength"],
    "ultimate_strength": ["uts", "tensile strength", "ultimate tensile strength"],
    # Crystal phase / microstructure labels (used for the phase filter)
    "phase": ["phase_label", "bcc/fcc/other", "phase", "microstructure"],
    "density": ["density_exp", "density_calc", "density"],
}
# --- Helper functions ---
def find_column_for_property(df, property_name, *, synonyms=None):
    """Return the first column of *df* whose name contains a synonym of *property_name*.

    Args:
        df: DataFrame whose column labels are searched (case-insensitively).
        property_name: property to look up, e.g. "hardness" or "yield strength".
            Spaces and underscores are interchangeable ("yield strength" finds
            the "yield_strength" entry).
        synonyms: optional mapping of property name -> list of substrings;
            defaults to the module-level ``property_synonyms`` table.

    Returns:
        The matching column label, or None if no column matches. Synonyms are
        tried in list order; within one synonym, columns in DataFrame order.
        Properties without a table entry fall back to matching their own name.
    """
    table = property_synonyms if synonyms is None else synonyms
    key = property_name.lower()
    # The query parser says "yield strength" while the table is keyed
    # "yield_strength"; accept either spelling so the synonyms are actually used.
    candidates = (
        table.get(key)
        or table.get(key.replace(" ", "_"))
        or table.get(key.replace("_", " "))
        or [property_name]
    )
    # NOTE(review): plain substring matching — short synonyms such as "ys" or
    # "hv" can match unrelated headers; tighten if the datasets grow.
    for syn in candidates:
        needle = syn.lower()
        for col in df.columns:
            if needle in str(col).lower():  # str(): tolerate non-string labels
                return col
    return None
def parse_query_to_filters(question):
    """Extract phase and numeric-property filters from a free-text question.

    Args:
        question: user question, matched case-insensitively.

    Returns:
        dict with an optional "phase" entry (first of "fcc", "bcc", "hcp",
        "other" found as a whole word) and, per numeric property mentioned,
        either "<op><number>" (e.g. ">=500", from "hardness >= 500") or
        "high"/"low" (from "high[est] <prop>" / "low[est] <prop>"). A
        high/low ranking overrides an explicit comparison for the same property.
    """
    query = question.lower()
    filters = {}

    # Whole-word match so e.g. "another" does not select the "other" phase.
    for phase in ("fcc", "bcc", "hcp", "other"):
        if re.search(rf"\b{phase}\b", query):
            filters["phase"] = phase
            break

    numeric_props = ("hardness", "bulk modulus", "yield strength", "ultimate strength", "density")
    for prop in numeric_props:
        prop_re = re.escape(prop)
        comparison = re.search(rf"{prop_re}\s*(>=|<=|=|>|<)\s*(\d+\.?\d*)", query)
        if comparison:
            op, val = comparison.groups()
            filters[prop] = f"{op}{val}"
        # \b keeps "below density" from being read as "low density".
        if re.search(rf"\bhigh(?:est)?\s+{prop_re}", query):
            filters[prop] = "high"
        elif re.search(rf"\blow(?:est)?\s+{prop_re}", query):
            filters[prop] = "low"
    return filters
# Comparison operators accepted in "<op><number>" filter strings.
_COMPARISONS = {
    ">": operator.gt,
    "<": operator.lt,
    ">=": operator.ge,
    "<=": operator.le,
    "=": operator.eq,
}


def apply_numeric_filter(df, col, filter_value):
    """Sort or filter *df* on column *col* according to *filter_value*.

    Args:
        df: DataFrame to filter; it is not modified.
        col: column holding the property values.
        filter_value: "high" (sort descending), "low" (sort ascending), or a
            comparison such as ">500", "<=7.5", "=8".

    Returns:
        A new, sorted or filtered DataFrame. Values are compared numerically
        (non-numeric cells count as NaN: they sort last and never pass a
        comparison). An unrecognised *filter_value* returns *df* unchanged.
    """
    # CSV columns can load as strings (object dtype), which would raise on
    # `>`/`<` and sort lexically ("1000" < "200"); always use a numeric view.
    if filter_value in ("high", "low"):
        return df.sort_values(
            by=col,
            ascending=(filter_value == "low"),
            key=lambda values: pd.to_numeric(values, errors="coerce"),
            kind="stable",  # keep dataset order among equal values
        )

    match = re.match(r"(>=|<=|=|>|<)(\d+\.?\d*)", filter_value)
    if not match:
        return df
    op, val_str = match.groups()
    numeric = pd.to_numeric(df[col], errors="coerce")
    return df[_COMPARISONS[op](numeric, float(val_str))]
def _find_phase_column(df):
    """Return the first column, in column order, whose name contains a phase synonym, else None."""
    phase_keys = property_synonyms["phase"]
    for col in df.columns:
        if any(key in str(col).lower() for key in phase_keys):
            return col
    return None


def _find_formula_column(df):
    """Return the alloy-formula column: exact "formula" if present, else a case-insensitive match, else None."""
    if "formula" in df.columns:
        return "formula"
    return next((col for col in df.columns if str(col).lower() == "formula"), None)


def filter_all_datasets(datasets, queries, top_n=10):
    """Apply the parsed query filters to every dataset and keep the top rows of each.

    Args:
        datasets: iterable of ``(DataFrame, name)`` pairs; the frames are not modified.
        queries: filter dict as produced by ``parse_query_to_filters`` — an
            optional "phase" entry plus property -> "<op><number>"/"high"/"low".
        top_n: maximum number of rows kept per dataset.

    Returns:
        dict mapping dataset name -> DataFrame with the formula column (when
        present), the matched property columns, the phase column (when a phase
        filter was applied) and a "Source" column set to the dataset name.
        Datasets lacking a column for any requested filter, or with no rows
        left after filtering, are omitted.
    """
    results = {}
    phase_filter = queries.get("phase")
    property_filters = [(prop, val) for prop, val in queries.items() if prop != "phase"]

    for df, name in datasets:
        df_filtered = df
        phase_col = None

        if phase_filter:
            phase_col = _find_phase_column(df_filtered)
            if phase_col is None:
                continue  # this dataset cannot answer a phase-constrained question
            # astype("string") makes .str usable on non-text columns (keeping NA);
            # regex=False treats the phase name as a literal.
            phase_text = df_filtered[phase_col].astype("string")
            df_filtered = df_filtered[
                phase_text.str.contains(phase_filter, case=False, na=False, regex=False)
            ]

        matched_cols = []
        for prop, filter_val in property_filters:
            col = find_column_for_property(df_filtered, prop)
            if col is None:
                matched_cols = None  # a requested property is missing from this dataset
                break
            df_filtered = df_filtered[df_filtered[col].notna()]
            df_filtered = apply_numeric_filter(df_filtered, col, filter_val)
            matched_cols.append(col)

        if matched_cols is None or df_filtered.empty:
            continue

        show_cols = [_find_formula_column(df_filtered), *matched_cols, phase_col]
        # Drop absent entries and duplicates (two properties may map to the same
        # column, and duplicate labels would duplicate output columns) while
        # keeping first-seen order.
        show_cols = list(dict.fromkeys(c for c in show_cols if c is not None))
        top_rows = df_filtered[show_cols].head(top_n).copy()
        top_rows["Source"] = name
        results[name] = top_rows
    return results
# --- Main HEA query function ---
def _build_prompt(question, faiss_text, csv_context):
    """Assemble the instruction, retrieved context and question into one user message."""
    return (
        "You are a materials scientist. Based on the following context, answer precisely.\n"
        f"FAISS context: {faiss_text}\n"
        f"CSV datasets context: {csv_context}\n"
        f"Question: {question}\n"
        "Answer:"
    )


def _generate_answer(prompt, max_new_tokens=512):
    """Run greedy generation on *prompt* with the Mistral model and return only the new text."""
    # The instruct model expects its chat format; apply_chat_template inserts it
    # (BOS included), so the tokenizer must not add special tokens a second time.
    chat_text = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(chat_text, return_tensors="pt", add_special_tokens=False).to("cpu")
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        # do_sample=True with temperature=0.0 is rejected by transformers
        # (temperature must be > 0 when sampling); greedy decoding is the
        # deterministic behaviour that setting was aiming for.
        do_sample=False,
        # The Mistral tokenizer defines no pad token; reuse EOS to avoid the warning.
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; drop the prompt tokens so the
    # answer does not echo the whole retrieved context back to the user.
    new_tokens = outputs[0, inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def query_hea(question, top_k=5):
    """Answer *question* from FAISS paper retrieval plus CSV dataset filtering.

    Args:
        question: free-text user question.
        top_k: number of FAISS passages retrieved, and max rows kept per dataset.

    Returns:
        ``(answer, merged_df, faiss_text)``: the generated answer text, the
        concatenated per-dataset CSV matches (an empty DataFrame when there are
        none), and the retrieved passages joined with newlines.
    """
    # FAISS retrieval
    faiss_results = faiss_index.similarity_search(question, k=top_k)
    faiss_text = "\n".join(doc.page_content for doc in faiss_results)

    # CSV filtering
    queries = parse_query_to_filters(question)
    csv_results_dict = filter_all_datasets(
        [(d1, "MPEA"), (d2, "MLPred"), (d3, "Achief")],
        queries,
        top_n=top_k,
    )
    csv_context = "".join(
        f"\n### {name} matches:\n{df_filtered.to_string(index=False)}\n"
        for name, df_filtered in csv_results_dict.items()
    )

    answer = _generate_answer(_build_prompt(question, faiss_text, csv_context))

    # Merge CSV results for the table output
    merged_df = (
        pd.concat(csv_results_dict.values(), ignore_index=True)
        if csv_results_dict
        else pd.DataFrame()
    )
    return answer, merged_df, faiss_text
# --- Gradio wrapper ---
def gradio_query(question):
    """Gradio callback: run query_hea on *question* with its default top_k.

    Returns the (answer, CSV matches, FAISS context) triple, in the order the
    interface's three output components expect.
    """
    answer, csv_matches, paper_context = query_hea(question)
    return answer, csv_matches, paper_context
# --- Launch Gradio interface ---
# Components are created in the same order as before: input first, then outputs.
_question_box = gr.Textbox(lines=2, placeholder="Ask about HEAs...")
# One component per element of the tuple returned by gradio_query.
_output_components = [
    gr.Textbox(label="LLM Answer"),
    gr.Dataframe(label="CSV Matches"),
    gr.Textbox(label="Paper Context (FAISS)"),
]
demo = gr.Interface(
    fn=gradio_query,
    inputs=_question_box,
    outputs=_output_components,
    title="🔬 HEA Query",
    description="Query HEA datasets + FAISS paper embeddings",
)
demo.launch()