taradutt007 commited on
Commit
dcaf215
·
verified ·
1 Parent(s): 8d96f02

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -26
app.py CHANGED
@@ -1,31 +1,33 @@
1
- import os
2
  import gradio as gr
3
  import pandas as pd
4
  import re
5
- from langchain_community.vectorstores import FAISS
6
  from langchain_community.embeddings import HuggingFaceEmbeddings
7
- from huggingface_hub import InferenceClient
 
 
8
 
9
- # === Paths ===
10
  CSV_FOLDER = "data"
11
  FAISS_INDEX_PATH = "data/faiss_index_hugface_BAAI_new"
12
 
13
- # === Load CSVs ===
14
  d1 = pd.read_csv(f"{CSV_FOLDER}/dataset1_clean.csv")
15
  d2 = pd.read_csv(f"{CSV_FOLDER}/dataset2_clean.csv")
16
  d3 = pd.read_csv(f"{CSV_FOLDER}/dataset3_clean.csv")
17
  print("✅ CSVs loaded")
18
 
19
- # === Load FAISS ===
20
  embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en")
21
  faiss_index = FAISS.load_local(FAISS_INDEX_PATH, embeddings, allow_dangerous_deserialization=True)
22
  print("✅ FAISS loaded")
23
 
24
- # === Hugging Face Inference API ===
25
- HF_API_TOKEN = os.environ.get("HF_API_TOKEN") # set this in Space secrets or locally
26
- client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.3", token=HF_API_TOKEN)
 
 
27
 
28
- # === Property synonyms ===
29
  property_synonyms = {
30
  "hardness": ["hv", "hardness", "vicker's hardness", "vickers hardness"],
31
  "bulk modulus": ["d_bulk (gpa)", "bulk modulus", "bulk_modulus"],
@@ -35,6 +37,7 @@ property_synonyms = {
35
  "density": ["density_exp", "density_calc", "density"]
36
  }
37
 
 
38
  def find_column_for_property(df, property_name):
39
  synonyms = property_synonyms.get(property_name.lower(), [property_name])
40
  for syn in synonyms:
@@ -86,8 +89,8 @@ def filter_all_datasets(datasets, queries, top_n=10):
86
  for df, name in datasets:
87
  df_filtered = df.copy()
88
  phase_filter = queries.get("phase", None)
89
- phase_col = None
90
  if phase_filter:
 
91
  for col in df_filtered.columns:
92
  if any(phase_key in col.lower() for phase_key in property_synonyms["phase"]):
93
  phase_col = col
@@ -107,26 +110,29 @@ def filter_all_datasets(datasets, queries, top_n=10):
107
  df_filtered = apply_numeric_filter(df_filtered, col, filter_val)
108
  if df_filtered is None or df_filtered.empty:
109
  continue
110
- show_cols = ["formula"] if "formula" in df_filtered.columns else []
 
 
111
  for prop in queries.keys():
112
  if prop == "phase":
113
  continue
114
  col = find_column_for_property(df_filtered, prop)
115
- if col and col in df_filtered.columns and col not in show_cols:
116
  show_cols.append(col)
117
- if phase_filter and phase_col and phase_col not in show_cols:
118
  show_cols.append(phase_col)
 
119
  df_filtered = df_filtered[show_cols].head(top_n).copy()
120
  df_filtered["Source"] = name
121
  results[name] = df_filtered
122
  return results
123
 
124
- # === Main Query Function ===
125
  def query_hea(question, top_k=5):
126
  # FAISS retrieval
127
  faiss_results = faiss_index.similarity_search(question, k=top_k)
128
  faiss_text = "\n".join([doc.page_content for doc in faiss_results])
129
-
130
  # CSV filtering
131
  queries = parse_query_to_filters(question)
132
  csv_results_dict = filter_all_datasets(
@@ -134,12 +140,11 @@ def query_hea(question, top_k=5):
134
  queries,
135
  top_n=top_k
136
  )
137
-
138
  csv_context = ""
139
  for name, df_filtered in csv_results_dict.items():
140
  csv_context += f"\n### {name} matches:\n{df_filtered.to_string(index=False)}\n"
141
 
142
- # Prompt for Mistral Instruct conversational API
143
  prompt = f"""
144
  You are a materials scientist. Based on the following context, answer precisely.
145
  FAISS context: {faiss_text}
@@ -147,18 +152,22 @@ CSV datasets context: {csv_context}
147
  Question: {question}
148
  Answer:
149
  """
150
- # Conversational API requires role-based input
151
- conversation_input = [{"role": "user", "content": prompt}]
152
- response = client.conversation(conversation_input)
153
- output_text = response[0]["generated_text"]
154
 
 
 
 
 
 
 
155
  merged_df = pd.concat(csv_results_dict.values(), ignore_index=True) if csv_results_dict else pd.DataFrame()
156
- return output_text, merged_df, faiss_text
157
 
 
 
 
158
  def gradio_query(question):
159
  return query_hea(question)
160
 
161
- # === Gradio Interface ===
162
  demo = gr.Interface(
163
  fn=gradio_query,
164
  inputs=gr.Textbox(lines=2, placeholder="Ask about HEAs..."),
@@ -171,6 +180,5 @@ demo = gr.Interface(
171
  description="Query HEA datasets + FAISS paper embeddings"
172
  )
173
 
174
- if __name__ == "__main__":
175
- demo.launch()
176
 
 
 
1
  import gradio as gr
2
  import pandas as pd
3
  import re
 
4
  from langchain_community.embeddings import HuggingFaceEmbeddings
5
+ from langchain_community.vectorstores import FAISS
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
+ import torch
8
 
9
+ # --- Paths ---
10
  CSV_FOLDER = "data"
11
  FAISS_INDEX_PATH = "data/faiss_index_hugface_BAAI_new"
12
 
13
+ # --- Load CSVs ---
14
  d1 = pd.read_csv(f"{CSV_FOLDER}/dataset1_clean.csv")
15
  d2 = pd.read_csv(f"{CSV_FOLDER}/dataset2_clean.csv")
16
  d3 = pd.read_csv(f"{CSV_FOLDER}/dataset3_clean.csv")
17
  print("✅ CSVs loaded")
18
 
19
+ # --- Load FAISS with dummy embeddings ---
20
  embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en")
21
  faiss_index = FAISS.load_local(FAISS_INDEX_PATH, embeddings, allow_dangerous_deserialization=True)
22
  print("✅ FAISS loaded")
23
 
24
+ # --- Load Mistral model ---
25
+ MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"
26
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
27
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto")
28
+ print("✅ Mistral model loaded")
29
 
30
+ # --- Property synonyms ---
31
  property_synonyms = {
32
  "hardness": ["hv", "hardness", "vicker's hardness", "vickers hardness"],
33
  "bulk modulus": ["d_bulk (gpa)", "bulk modulus", "bulk_modulus"],
 
37
  "density": ["density_exp", "density_calc", "density"]
38
  }
39
 
40
+ # --- Helper functions ---
41
  def find_column_for_property(df, property_name):
42
  synonyms = property_synonyms.get(property_name.lower(), [property_name])
43
  for syn in synonyms:
 
89
  for df, name in datasets:
90
  df_filtered = df.copy()
91
  phase_filter = queries.get("phase", None)
 
92
  if phase_filter:
93
+ phase_col = None
94
  for col in df_filtered.columns:
95
  if any(phase_key in col.lower() for phase_key in property_synonyms["phase"]):
96
  phase_col = col
 
110
  df_filtered = apply_numeric_filter(df_filtered, col, filter_val)
111
  if df_filtered is None or df_filtered.empty:
112
  continue
113
+ show_cols = []
114
+ if "formula" in df_filtered.columns:
115
+ show_cols.append("formula")
116
  for prop in queries.keys():
117
  if prop == "phase":
118
  continue
119
  col = find_column_for_property(df_filtered, prop)
120
+ if col and col in df_filtered.columns:
121
  show_cols.append(col)
122
+ if phase_filter and phase_col:
123
  show_cols.append(phase_col)
124
+ show_cols = [c for c in show_cols if c in df_filtered.columns]
125
  df_filtered = df_filtered[show_cols].head(top_n).copy()
126
  df_filtered["Source"] = name
127
  results[name] = df_filtered
128
  return results
129
 
130
+ # --- Main HEA query function ---
131
  def query_hea(question, top_k=5):
132
  # FAISS retrieval
133
  faiss_results = faiss_index.similarity_search(question, k=top_k)
134
  faiss_text = "\n".join([doc.page_content for doc in faiss_results])
135
+
136
  # CSV filtering
137
  queries = parse_query_to_filters(question)
138
  csv_results_dict = filter_all_datasets(
 
140
  queries,
141
  top_n=top_k
142
  )
 
143
  csv_context = ""
144
  for name, df_filtered in csv_results_dict.items():
145
  csv_context += f"\n### {name} matches:\n{df_filtered.to_string(index=False)}\n"
146
 
147
+ # --- Prompt for Mistral ---
148
  prompt = f"""
149
  You are a materials scientist. Based on the following context, answer precisely.
150
  FAISS context: {faiss_text}
 
152
  Question: {question}
153
  Answer:
154
  """
 
 
 
 
155
 
156
+ # Tokenize and generate
157
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
158
+ outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.0)
159
+ answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
160
+
161
+ # Merge CSV results
162
  merged_df = pd.concat(csv_results_dict.values(), ignore_index=True) if csv_results_dict else pd.DataFrame()
 
163
 
164
+ return answer, merged_df, faiss_text
165
+
166
+ # --- Gradio wrapper ---
167
  def gradio_query(question):
168
  return query_hea(question)
169
 
170
+ # --- Launch Gradio interface ---
171
  demo = gr.Interface(
172
  fn=gradio_query,
173
  inputs=gr.Textbox(lines=2, placeholder="Ask about HEAs..."),
 
180
  description="Query HEA datasets + FAISS paper embeddings"
181
  )
182
 
183
+ demo.launch()
 
184