Mohamed284 committed on
Commit
7918f2a
·
verified ·
1 Parent(s): d2ff146

stanford-crfm/BioMedLM

Browse files
Files changed (1) hide show
  1. app.py +31 -27
app.py CHANGED
@@ -1,26 +1,23 @@
1
  import os
2
  import json
3
  import pandas as pd
4
- from transformers import AutoTokenizer, AutoModelForCausalLM
5
- from langchain_ollama import OllamaLLM, OllamaEmbeddings
6
  from langchain_community.vectorstores import FAISS
7
  from langchain_core.prompts import PromptTemplate
8
  from langchain_core.output_parsers import StrOutputParser
9
  from operator import itemgetter
10
- from huggingface_hub import HfApi, HfFolder
11
  import gradio as gr
12
- from huggingface_hub import login, InferenceClient
13
- from langchain_community.embeddings import HuggingFaceEmbeddings # Updated import
14
-
15
 
 
16
  USE_HF = True
17
- MODEL_NAME = "BioMistral/BioMistral-7B"
 
18
 
 
19
  with open('AskNatureNet_data.json', 'r', encoding='utf-8') as f:
20
  data = json.load(f)
21
-
22
  df = pd.DataFrame(data)
23
-
24
  documents = [
25
  f"Source: {item['Source']}\nApplication: {item['Application']}\nFunction1: {item['Function1']}\nStrategy: {item['Strategy']}"
26
  for item in data
@@ -28,31 +25,39 @@ documents = [
28
 
29
  if USE_HF:
30
  print("Using Hugging Face model...")
31
-
32
  huggingface_token = os.environ.get("AskNature_RAG")
33
- # Load tokenizer and model from Hugging Face Hub
34
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=huggingface_token)
 
 
 
 
 
 
 
 
35
  model = AutoModelForCausalLM.from_pretrained(
36
  MODEL_NAME,
37
  device_map="auto",
38
  offload_folder="offload", # Specify the offload folder
39
- token=huggingface_token
 
40
  )
41
  embeddings = HuggingFaceEmbeddings(model_name=MODEL_NAME)
42
  lang_model = model
43
  else:
44
- print("Using local Ollama model...")
45
- MODEL = "jsk/bio-mistral"
46
- embeddings = OllamaEmbeddings(model=MODEL)
47
- lang_model = OllamaLLM(model=MODEL)
48
 
49
- batch_size = 16
50
  batched_embeddings = [
51
- embeddings.embed_documents(documents[i:i + batch_size])
52
- for i in range(0, len(documents), batch_size)
53
  ]
54
  batched_embeddings = [embed for batch in batched_embeddings for embed in batch]
55
 
 
56
  index_path = "faiss_index"
57
  if os.path.exists(index_path):
58
  vectorstore = FAISS.load_local(index_path, embeddings)
@@ -62,30 +67,29 @@ else:
62
 
63
  retriever = vectorstore.as_retriever()
64
 
 
65
  template = """
66
  Answer the question based on the context below. If you can't
67
  answer the question, reply "I don't know".
68
-
69
  Context: {context}
70
-
71
  Question: {question}
72
  """
73
  prompt = PromptTemplate.from_template(template)
74
 
 
75
  chain = {
76
  "context": itemgetter("question") | retriever,
77
  "question": itemgetter("question"),
78
  } | prompt | lang_model | StrOutputParser()
79
 
 
80
def rag_qa(question):
    """Answer *question* through the RAG chain.

    Returns the chain's string answer, or an "Error: ..." string if
    invocation raises (the app surfaces failures as chat text rather
    than crashing).
    """
    try:
        answer = chain.invoke({'question': question})
    except Exception as e:
        return f"Error: {str(e)}"
    return answer
85
 
86
- # Chatbot functionality
87
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
88
-
89
  def respond(
90
  message,
91
  history: list[tuple[str, str]],
@@ -114,10 +118,10 @@ def respond(
114
  top_p=top_p,
115
  ):
116
  token = message.choices[0].delta.content
117
-
118
  response += token
119
  yield response
120
 
 
121
  demo = gr.ChatInterface(
122
  respond,
123
  additional_inputs=[
@@ -135,4 +139,4 @@ demo = gr.ChatInterface(
135
  )
136
 
137
  if __name__ == "__main__":
138
- demo.launch()
 
1
  import os
2
  import json
3
  import pandas as pd
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 
5
  from langchain_community.vectorstores import FAISS
6
  from langchain_core.prompts import PromptTemplate
7
  from langchain_core.output_parsers import StrOutputParser
8
  from operator import itemgetter
 
9
  import gradio as gr
10
+ from langchain_community.embeddings import HuggingFaceEmbeddings
 
 
11
 
12
+ # Configuration
13
  USE_HF = True
14
+ MODEL_NAME = "stanford-crfm/BioMedLM"
15
+ BATCH_SIZE = 8 # Adjusted batch size for memory optimization
16
 
17
+ # Load data
18
  with open('AskNatureNet_data.json', 'r', encoding='utf-8') as f:
19
  data = json.load(f)
 
20
  df = pd.DataFrame(data)
 
21
  documents = [
22
  f"Source: {item['Source']}\nApplication: {item['Application']}\nFunction1: {item['Function1']}\nStrategy: {item['Strategy']}"
23
  for item in data
 
25
 
26
  if USE_HF:
27
  print("Using Hugging Face model...")
28
+
29
  huggingface_token = os.environ.get("AskNature_RAG")
30
+
31
+ # Quantization configuration for 4-bit precision
32
+ bnb_config = BitsAndBytesConfig(
33
+ load_in_4bit=True,
34
+ bnb_4bit_use_double_quant=True,
35
+ bnb_4bit_quant_type="nf4"
36
+ )
37
+
38
+ # Load tokenizer and model with offloading and quantization
39
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=huggingface_token)
40
  model = AutoModelForCausalLM.from_pretrained(
41
  MODEL_NAME,
42
  device_map="auto",
43
  offload_folder="offload", # Specify the offload folder
44
+ quantization_config=bnb_config,
45
+ use_auth_token=huggingface_token
46
  )
47
  embeddings = HuggingFaceEmbeddings(model_name=MODEL_NAME)
48
  lang_model = model
49
  else:
50
+ print("Using local model...")
51
+ # Local model loading logic here
 
 
52
 
53
+ # Generate embeddings in batches
54
  batched_embeddings = [
55
+ embeddings.embed_documents(documents[i:i + BATCH_SIZE])
56
+ for i in range(0, len(documents), BATCH_SIZE)
57
  ]
58
  batched_embeddings = [embed for batch in batched_embeddings for embed in batch]
59
 
60
+ # FAISS index handling
61
  index_path = "faiss_index"
62
  if os.path.exists(index_path):
63
  vectorstore = FAISS.load_local(index_path, embeddings)
 
67
 
68
  retriever = vectorstore.as_retriever()
69
 
70
+ # Prompt template
71
  template = """
72
  Answer the question based on the context below. If you can't
73
  answer the question, reply "I don't know".
 
74
  Context: {context}
 
75
  Question: {question}
76
  """
77
  prompt = PromptTemplate.from_template(template)
78
 
79
+ # Chain definition
80
  chain = {
81
  "context": itemgetter("question") | retriever,
82
  "question": itemgetter("question"),
83
  } | prompt | lang_model | StrOutputParser()
84
 
85
+ # Question-answering function
86
def rag_qa(question):
    """Run the retrieval-augmented chain on *question*.

    Any exception from the chain is converted into an "Error: ..."
    string so the caller (the chat UI) always receives text.
    """
    payload = {'question': question}
    try:
        return chain.invoke(payload)
    except Exception as e:
        return f"Error: {str(e)}"
91
 
92
+ # Gradio chatbot interface
 
 
93
  def respond(
94
  message,
95
  history: list[tuple[str, str]],
 
118
  top_p=top_p,
119
  ):
120
  token = message.choices[0].delta.content
 
121
  response += token
122
  yield response
123
 
124
+ # Gradio interface setup
125
  demo = gr.ChatInterface(
126
  respond,
127
  additional_inputs=[
 
139
  )
140
 
141
  if __name__ == "__main__":
142
+ demo.launch()