NSamson1 committed on
Commit 6f344a4 · verified · 1 Parent(s): 0f5b97b

Update app.py

Files changed (1)
  1. app.py +97 -47
app.py CHANGED
@@ -10,7 +10,6 @@ from langchain_community.vectorstores import Chroma
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
- # Transformers and datasets
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
@@ -18,57 +17,100 @@ from transformers import (
    pipeline,
    BitsAndBytesConfig
)
+ import torch  # Explicitly imported for CUDA management

# ====================== CONFIGURATION ======================
API_KEY = "Samson"
- MODEL_NAME = "microsoft/phi-2"
-
- # Set up logging
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-
- # ---------------------- RAG Setup --------------------------
- # 1. Load dataset
- ds = load_dataset("maxpro291/bankfaqs_dataset")
- data = ds['train'][:]
- Bank_Data = pd.DataFrame({
-     'question': [entry for entry in data['text'] if entry.startswith("Q:")],
-     'answer': [entry for entry in data['text'] if entry.startswith("A:")]
- })
-
- # 2. Create vector store
- embed_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
- vectorstore = Chroma.from_texts(
-     texts=[f"Q: {q}\nA: {a}" for q, a in zip(Bank_Data['question'], Bank_Data['answer'])],
-     embedding=embed_model,
-     persist_directory="./chroma_db_bank"
- )
- retriever = vectorstore.as_retriever()
-
- # 3. Initialize LLM
- quant_config = BitsAndBytesConfig(
-     load_in_4bit=True,
-     bnb_4bit_compute_dtype="float16",
-     bnb_4bit_quant_type="nf4"
- )
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
- model = AutoModelForCausalLM.from_pretrained(
-     MODEL_NAME,
-     device_map="auto",
-     trust_remote_code=True,
-     quantization_config=quant_config
- )
-
- # Create LangChain pipeline
- llm_pipeline = pipeline(
-     "text-generation",
-     model=model,
-     tokenizer=tokenizer,
-     max_new_tokens=512,
-     temperature=0.7,
-     top_p=0.95
- )
-
- # 4. Build RAG chain
+ MODEL_NAME = "microsoft/phi-2"
+ EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
+ # ===========================================================
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s'
+ )
+
+ # Clear CUDA cache if using GPU
+ if torch.cuda.is_available():
+     torch.cuda.empty_cache()
+
+ # ------------------------------------------------------------------
+ # 1. Load and Prepare Dataset
+ # ------------------------------------------------------------------
+ def load_data():
+     try:
+         ds = load_dataset("maxpro291/bankfaqs_dataset")
+         data = ds['train'][:]
+         questions = [entry for entry in data['text'] if entry.startswith("Q:")]
+         answers = [entry for entry in data['text'] if entry.startswith("A:")]
+         return pd.DataFrame({'question': questions, 'answer': answers})
+     except Exception as e:
+         logging.error(f"Error loading dataset: {str(e)}")
+         raise
+
+ # ------------------------------------------------------------------
+ # 2. Initialize Embeddings and Vector Store
+ # ------------------------------------------------------------------
+ def init_vectordb(data):
+     try:
+         embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
+         texts = [f"Q: {q}\nA: {a}" for q, a in zip(data['question'], data['answer'])]
+         return Chroma.from_texts(
+             texts=texts,
+             embedding=embeddings,
+             persist_directory="./chroma_db_bank"
+         )
+     except Exception as e:
+         logging.error(f"Error initializing vector store: {str(e)}")
+         raise
+
+ # ------------------------------------------------------------------
+ # 3. Initialize LLM with Quantization
+ # ------------------------------------------------------------------
+ def load_llm():
+     try:
+         quantization_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_compute_dtype="float16"
+         )
+
+         tokenizer = AutoTokenizer.from_pretrained(
+             MODEL_NAME,
+             trust_remote_code=True,
+             padding_side="left"  # Critical for phi-2
+         )
+
+         model = AutoModelForCausalLM.from_pretrained(
+             MODEL_NAME,
+             device_map="auto",
+             trust_remote_code=True,
+             quantization_config=quantization_config
+         )
+
+         return pipeline(
+             "text-generation",
+             model=model,
+             tokenizer=tokenizer,
+             max_new_tokens=512,
+             temperature=0.7,
+             top_p=0.95,
+             do_sample=True
+         )
+     except Exception as e:
+         logging.error(f"Error loading LLM: {str(e)}")
+         raise
+
+ # Initialize components
+ bank_data = load_data()
+ retriever = init_vectordb(bank_data).as_retriever()
+ llm_pipeline = load_llm()
+
+ # ------------------------------------------------------------------
+ # 4. Build RAG Chain
+ # ------------------------------------------------------------------
template = """You are a banking assistant. Use context if relevant:
Context: {context}
Question: {question}
@@ -82,7 +124,9 @@ rag_chain = (
    | StrOutputParser()
)

- # ---------------------- API Setup --------------------------
+ # ------------------------------------------------------------------
+ # 5. FastAPI Setup
+ # ------------------------------------------------------------------
app = FastAPI()

def validate_api_key(api_key: str = Header(None)):
@@ -98,7 +142,9 @@ async def chat_endpoint(question: str, authorization: str = Header(None)):
        response += chunk
    return {"response": response}

- # -------------------- Gradio Interface ---------------------
+ # ------------------------------------------------------------------
+ # 6. Gradio Interface
+ # ------------------------------------------------------------------
def respond(message, history):
    return next(rag_chain.stream(message))

@@ -113,11 +159,15 @@ demo = gr.ChatInterface(
    theme="glass"
)

- # --------------------- Launch Servers ----------------------
+ # ------------------------------------------------------------------
+ # 7. Launch Servers
+ # ------------------------------------------------------------------
if __name__ == "__main__":
+     # Start Gradio in separate thread
    threading.Thread(
        target=demo.launch,
-         kwargs={"server_name": "0.0.0.0", "server_port": 7860}
+         kwargs={"server_name": "0.0.0.0", "server_port": 7860, "share": False}
    ).start()

+     # Start FastAPI
    uvicorn.run(app, host="0.0.0.0", port=8000)
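Note that the `@@ -82,7 +124,9 @@` hunk opens at `rag_chain = (` but shows only the chain's tail (`| StrOutputParser()`); the wiring between the prompt template and the parser sits outside the diff context. For orientation, a minimal sketch of a typical LCEL chain consistent with the visible pieces (`retriever`, `template`, `llm_pipeline`) follows. The `HuggingFacePipeline` wrapper and the `format_docs` helper are illustrative assumptions, not necessarily what app.py actually does:

from langchain_community.llms import HuggingFacePipeline
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

# Assumed wrapper: exposes the transformers pipeline as a LangChain LLM
llm = HuggingFacePipeline(pipeline=llm_pipeline)
prompt = PromptTemplate.from_template(template)

def format_docs(docs):
    # Hypothetical helper: join retrieved documents into one context string
    return "\n\n".join(doc.page_content for doc in docs)

rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)

Piping `retriever | format_docs` coerces the plain function into a RunnableLambda, so the retrieved documents reach the `{context}` slot as a single string rather than a raw document list.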
 
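For testing the FastAPI side: `chat_endpoint` declares `question: str` (which FastAPI treats as a query parameter) and reads the key from the `authorization` header, and the handler accumulates `rag_chain.stream(...)` chunks server-side, so the client receives one complete JSON object rather than a stream. A hypothetical client call follows; the `/chat` path and the POST method are assumptions, since the route decorator sits outside the diff context:

import requests

resp = requests.post(
    "http://localhost:8000/chat",  # assumed route; the decorator is not shown in the diff
    params={"question": "How do I block a lost debit card?"},
    headers={"Authorization": "Samson"},  # must match API_KEY from the CONFIGURATION block
)
print(resp.json()["response"])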