NSamson1 commited on
Commit
174562d
·
verified ·
1 Parent(s): 682799b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -77
app.py CHANGED
@@ -9,65 +9,49 @@ from langchain_core.output_parsers import StrOutputParser
9
  from langchain_core.runnables import RunnablePassthrough
10
  import gradio as gr
11
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
 
 
 
 
 
 
 
 
12
 
13
  # Set up logging
14
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
15
 
16
- # ------------------------------------------------------------------
17
- # 1. Load and Prepare the Bank FAQ Dataset (UNCHANGED)
18
- # ------------------------------------------------------------------
19
  ds = load_dataset("maxpro291/bankfaqs_dataset")
20
- train_ds = ds['train']
21
- data = train_ds[:] # load all examples
22
-
23
- questions = []
24
- answers = []
25
- for entry in data['text']:
26
- if entry.startswith("Q:"):
27
- questions.append(entry)
28
- elif entry.startswith("A:"):
29
- answers.append(entry)
30
-
31
- Bank_Data = pd.DataFrame({'question': questions, 'answer': answers})
32
-
33
- context_data = []
34
- for i in range(len(Bank_Data)):
35
- context = f"Question: {Bank_Data.iloc[i]['question']} Answer: {Bank_Data.iloc[i]['answer']}"
36
- context_data.append(context)
37
 
38
- # ------------------------------------------------------------------
39
- # 2. Create the Vector Store for Retrieval (UNCHANGED)
40
- # ------------------------------------------------------------------
41
  embed_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
42
-
43
  vectorstore = Chroma.from_texts(
44
- texts=context_data,
45
  embedding=embed_model,
46
  persist_directory="./chroma_db_bank"
47
  )
48
  retriever = vectorstore.as_retriever()
49
 
50
- # ------------------------------------------------------------------
51
- # 3. Initialize PHI-2 Model (MODIFIED SECTION)
52
- # ------------------------------------------------------------------
53
- model_name = "microsoft/phi-2"
54
-
55
- # Configure 4-bit quantization for efficient loading
56
- quantization_config = BitsAndBytesConfig(
57
  load_in_4bit=True,
58
  bnb_4bit_compute_dtype="float16",
59
  bnb_4bit_quant_type="nf4"
60
  )
61
-
62
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
63
  model = AutoModelForCausalLM.from_pretrained(
64
- model_name,
65
  device_map="auto",
66
  trust_remote_code=True,
67
- quantization_config=quantization_config
68
  )
69
-
70
- # Create text-generation pipeline with Phi-2 specific settings
71
  pipe = pipeline(
72
  "text-generation",
73
  model=model,
@@ -75,25 +59,16 @@ pipe = pipeline(
75
  max_new_tokens=512,
76
  temperature=0.7,
77
  top_p=0.95,
78
- repetition_penalty=1.15,
79
- do_sample=True
80
  )
81
-
82
- # Wrap the pipeline in LangChain's HuggingFacePipeline
83
  huggingface_model = HuggingFacePipeline(pipeline=pipe)
84
 
85
- # ------------------------------------------------------------------
86
- # 4. Build the RAG Chain (UNCHANGED)
87
- # ------------------------------------------------------------------
88
- template = (
89
- "You are a helpful banking assistant. "
90
- "Use the provided context if it is relevant to answer the question. "
91
- "If not, answer using your general banking knowledge.\n"
92
- "Question: {question}\n"
93
- "Answer:"
94
- )
95
  rag_prompt = PromptTemplate.from_template(template)
96
-
97
  rag_chain = (
98
  {"context": retriever, "question": RunnablePassthrough()}
99
  | rag_prompt
@@ -101,38 +76,56 @@ rag_chain = (
101
  | StrOutputParser()
102
  )
103
 
104
- # ------------------------------------------------------------------
105
- # 5. Gradio Chat Interface (UNCHANGED)
106
- # ------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  def rag_memory_stream(message, history):
108
  partial_text = ""
109
  for new_text in rag_chain.stream(message):
110
  partial_text += new_text
111
  yield partial_text
112
 
113
- examples = [
114
- "I want to open an account",
115
- "What is a savings account?",
116
- "How do I use an ATM?",
117
- "How can I resolve a bank account issue?"
118
- ]
119
-
120
- title = "Your Personal Banking Assistant 💬"
121
- description = (
122
- "Welcome! I'm here to answer your questions about banking and related topics. "
123
- "Ask me anything, and I'll do my best to assist you."
124
- )
125
-
126
  demo = gr.ChatInterface(
127
  fn=rag_memory_stream,
128
- title=title,
129
- description=description,
130
- examples=examples,
131
- theme="glass",
 
 
 
 
132
  )
133
 
134
- # ------------------------------------------------------------------
135
- # 6. Launch the App (UNCHANGED)
136
- # ------------------------------------------------------------------
 
137
  if __name__ == "__main__":
138
- demo.launch(share=True)
 
 
 
 
 
 
9
  from langchain_core.runnables import RunnablePassthrough
10
  import gradio as gr
11
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
12
+ from fastapi import FastAPI, Header, HTTPException
13
+ import threading
14
+ import uvicorn
15
+
16
+ # ====================== CONFIGURATION ======================
17
+ API_KEY = "Samson" # Your hardcoded API key
18
+ MODEL_NAME = "microsoft/phi-2" # Using Phi-2 model
19
+ # ===========================================================
20
 
21
  # Set up logging
22
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
23
 
24
+ # ---------------------- RAG Setup --------------------------
25
+ # 1. Load and prepare dataset
 
26
  ds = load_dataset("maxpro291/bankfaqs_dataset")
27
+ data = ds['train'][:]
28
+ Bank_Data = pd.DataFrame({
29
+ 'question': [entry for entry in data['text'] if entry.startswith("Q:")],
30
+ 'answer': [entry for entry in data['text'] if entry.startswith("A:")]
31
+ })
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ # 2. Create vector store
 
 
34
  embed_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
 
35
  vectorstore = Chroma.from_texts(
36
+ texts=[f"Q: {q}\nA: {a}" for q, a in zip(Bank_Data['question'], Bank_Data['answer'])],
37
  embedding=embed_model,
38
  persist_directory="./chroma_db_bank"
39
  )
40
  retriever = vectorstore.as_retriever()
41
 
42
+ # 3. Initialize LLM with 4-bit quantization
43
+ quant_config = BitsAndBytesConfig(
 
 
 
 
 
44
  load_in_4bit=True,
45
  bnb_4bit_compute_dtype="float16",
46
  bnb_4bit_quant_type="nf4"
47
  )
48
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
 
49
  model = AutoModelForCausalLM.from_pretrained(
50
+ MODEL_NAME,
51
  device_map="auto",
52
  trust_remote_code=True,
53
+ quantization_config=quant_config
54
  )
 
 
55
  pipe = pipeline(
56
  "text-generation",
57
  model=model,
 
59
  max_new_tokens=512,
60
  temperature=0.7,
61
  top_p=0.95,
62
+ repetition_penalty=1.15
 
63
  )
 
 
64
  huggingface_model = HuggingFacePipeline(pipeline=pipe)
65
 
66
+ # 4. Build RAG chain
67
+ template = """You are a banking assistant. Use context if relevant:
68
+ Context: {context}
69
+ Question: {question}
70
+ Answer:"""
 
 
 
 
 
71
  rag_prompt = PromptTemplate.from_template(template)
 
72
  rag_chain = (
73
  {"context": retriever, "question": RunnablePassthrough()}
74
  | rag_prompt
 
76
  | StrOutputParser()
77
  )
78
 
79
+ # ---------------------- API Setup --------------------------
80
+ app = FastAPI()
81
+
82
+ def validate_api_key(api_key: str = Header(None)):
83
+ if api_key != API_KEY:
84
+ raise HTTPException(status_code=401, detail="Invalid API Key")
85
+ return True
86
+
87
+ @app.post("/chat")
88
+ async def chat_endpoint(
89
+ question: str,
90
+ authorization: str = Header(None),
91
+ ):
92
+ validate_api_key(authorization)
93
+ response = ""
94
+ for chunk in rag_chain.stream(question):
95
+ response += chunk
96
+ return {"response": response}
97
+
98
+ @app.get("/health")
99
+ async def health_check():
100
+ return {"status": "healthy"}
101
+
102
+ # -------------------- Gradio Interface ---------------------
103
  def rag_memory_stream(message, history):
104
  partial_text = ""
105
  for new_text in rag_chain.stream(message):
106
  partial_text += new_text
107
  yield partial_text
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  demo = gr.ChatInterface(
110
  fn=rag_memory_stream,
111
+ title="Banking Assistant 🔒 (API Key: Samson)",
112
+ description="Welcome! Use API key 'Samson' to access the /chat endpoint",
113
+ examples=[
114
+ "How do I open an account?",
115
+ "What's the interest rate for savings?",
116
+ "How do I apply for a loan?"
117
+ ],
118
+ theme="glass"
119
  )
120
 
121
+ # --------------------- Launch Servers ----------------------
122
+ def run_gradio():
123
+ demo.launch(server_name="0.0.0.0", server_port=7860)
124
+
125
  if __name__ == "__main__":
126
+ # Start Gradio in separate thread
127
+ gradio_thread = threading.Thread(target=run_gradio)
128
+ gradio_thread.start()
129
+
130
+ # Start FastAPI
131
+ uvicorn.run(app, host="0.0.0.0", port=8000)