NSamson1 commited on
Commit
b053964
·
verified ·
1 Parent(s): 3de49cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -40
app.py CHANGED
@@ -2,11 +2,11 @@ import os
2
  import pandas as pd
3
  import logging
4
  from datasets import load_dataset
5
- from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
6
- from langchain_chroma import Chroma
7
- from langchain_core.prompts import PromptTemplate
8
- from langchain_core.output_parsers import StrOutputParser
9
- from langchain_core.runnables import RunnablePassthrough
10
  import gradio as gr
11
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
12
  from fastapi import FastAPI, Header, HTTPException
@@ -14,15 +14,15 @@ import threading
14
  import uvicorn
15
 
16
  # ====================== CONFIGURATION ======================
17
- API_KEY = "Samson" # Your hardcoded API key
18
- MODEL_NAME = "microsoft/phi-2" # Using Phi-2 model
19
  # ===========================================================
20
 
21
  # Set up logging
22
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
23
 
24
  # ---------------------- RAG Setup --------------------------
25
- # 1. Load and prepare dataset
26
  ds = load_dataset("maxpro291/bankfaqs_dataset")
27
  data = ds['train'][:]
28
  Bank_Data = pd.DataFrame({
@@ -39,7 +39,7 @@ vectorstore = Chroma.from_texts(
39
  )
40
  retriever = vectorstore.as_retriever()
41
 
42
- # 3. Initialize LLM with 4-bit quantization
43
  quant_config = BitsAndBytesConfig(
44
  load_in_4bit=True,
45
  bnb_4bit_compute_dtype="float16",
@@ -52,27 +52,28 @@ model = AutoModelForCausalLM.from_pretrained(
52
  trust_remote_code=True,
53
  quantization_config=quant_config
54
  )
55
- pipe = pipeline(
 
 
56
  "text-generation",
57
  model=model,
58
  tokenizer=tokenizer,
59
  max_new_tokens=512,
60
  temperature=0.7,
61
- top_p=0.95,
62
- repetition_penalty=1.15
63
  )
64
- huggingface_model = HuggingFacePipeline(pipeline=pipe)
65
 
66
  # 4. Build RAG chain
67
  template = """You are a banking assistant. Use context if relevant:
68
  Context: {context}
69
  Question: {question}
70
  Answer:"""
71
- rag_prompt = PromptTemplate.from_template(template)
 
72
  rag_chain = (
73
  {"context": retriever, "question": RunnablePassthrough()}
74
- | rag_prompt
75
- | huggingface_model
76
  | StrOutputParser()
77
  )
78
 
@@ -85,47 +86,33 @@ def validate_api_key(api_key: str = Header(None)):
85
  return True
86
 
87
  @app.post("/chat")
88
- async def chat_endpoint(
89
- question: str,
90
- authorization: str = Header(None),
91
- ):
92
  validate_api_key(authorization)
93
  response = ""
94
  for chunk in rag_chain.stream(question):
95
  response += chunk
96
  return {"response": response}
97
 
98
- @app.get("/health")
99
- async def health_check():
100
- return {"status": "healthy"}
101
-
102
  # -------------------- Gradio Interface ---------------------
103
- def rag_memory_stream(message, history):
104
- partial_text = ""
105
- for new_text in rag_chain.stream(message):
106
- partial_text += new_text
107
- yield partial_text
108
 
109
  demo = gr.ChatInterface(
110
- fn=rag_memory_stream,
111
- title="Banking Assistant 🔒 (API Key: Samson)",
112
- description="Welcome! Use API key 'Samson' to access the /chat endpoint",
113
  examples=[
114
  "How do I open an account?",
115
- "What's the interest rate for savings?",
116
  "How do I apply for a loan?"
117
  ],
118
  theme="glass"
119
  )
120
 
121
  # --------------------- Launch Servers ----------------------
122
- def run_gradio():
123
- demo.launch(server_name="0.0.0.0", server_port=7860)
124
-
125
  if __name__ == "__main__":
126
- # Start Gradio in separate thread
127
- gradio_thread = threading.Thread(target=run_gradio)
128
- gradio_thread.start()
 
129
 
130
- # Start FastAPI
131
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
2
  import pandas as pd
3
  import logging
4
  from datasets import load_dataset
5
+ from langchain.embeddings import HuggingFaceEmbeddings # Updated import
6
+ from langchain.vectorstores import Chroma # Updated import
7
+ from langchain.prompts import PromptTemplate # Updated import
8
+ from langchain.schema.output_parser import StrOutputParser # Updated import
9
+ from langchain.schema.runnable import RunnablePassthrough # Updated import
10
  import gradio as gr
11
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
12
  from fastapi import FastAPI, Header, HTTPException
 
14
  import uvicorn
15
 
16
  # ====================== CONFIGURATION ======================
17
+ API_KEY = "Samson"
18
+ MODEL_NAME = "microsoft/phi-2"
19
  # ===========================================================
20
 
21
  # Set up logging
22
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
23
 
24
  # ---------------------- RAG Setup --------------------------
25
+ # 1. Load dataset
26
  ds = load_dataset("maxpro291/bankfaqs_dataset")
27
  data = ds['train'][:]
28
  Bank_Data = pd.DataFrame({
 
39
  )
40
  retriever = vectorstore.as_retriever()
41
 
42
+ # 3. Initialize LLM
43
  quant_config = BitsAndBytesConfig(
44
  load_in_4bit=True,
45
  bnb_4bit_compute_dtype="float16",
 
52
  trust_remote_code=True,
53
  quantization_config=quant_config
54
  )
55
+
56
+ # Create LangChain pipeline
57
+ llm_pipeline = pipeline(
58
  "text-generation",
59
  model=model,
60
  tokenizer=tokenizer,
61
  max_new_tokens=512,
62
  temperature=0.7,
63
+ top_p=0.95
 
64
  )
 
65
 
66
  # 4. Build RAG chain
67
  template = """You are a banking assistant. Use context if relevant:
68
  Context: {context}
69
  Question: {question}
70
  Answer:"""
71
+ prompt = PromptTemplate.from_template(template)
72
+
73
  rag_chain = (
74
  {"context": retriever, "question": RunnablePassthrough()}
75
+ | prompt
76
+ | llm_pipeline
77
  | StrOutputParser()
78
  )
79
 
 
86
  return True
87
 
88
  @app.post("/chat")
89
+ async def chat_endpoint(question: str, authorization: str = Header(None)):
 
 
 
90
  validate_api_key(authorization)
91
  response = ""
92
  for chunk in rag_chain.stream(question):
93
  response += chunk
94
  return {"response": response}
95
 
 
 
 
 
96
  # -------------------- Gradio Interface ---------------------
97
+ def respond(message, history):
98
+ return next(rag_chain.stream(message))
 
 
 
99
 
100
  demo = gr.ChatInterface(
101
+ fn=respond,
102
+ title="Banking Assistant 🔒",
 
103
  examples=[
104
  "How do I open an account?",
105
+ "What's the interest rate?",
106
  "How do I apply for a loan?"
107
  ],
108
  theme="glass"
109
  )
110
 
111
  # --------------------- Launch Servers ----------------------
 
 
 
112
  if __name__ == "__main__":
113
+ threading.Thread(
114
+ target=demo.launch,
115
+ kwargs={"server_name": "0.0.0.0", "server_port": 7860}
116
+ ).start()
117
 
 
118
  uvicorn.run(app, host="0.0.0.0", port=8000)