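"""Gradio RAG chat demo: ConversationalRetrievalChain over local FAISS stores.

Module-level caches hold the embedder, FAISS indexes, and model pipelines;
each browser session gets its own SessionChain (and its own memory) via
gr.State.
"""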
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# Note: on LangChain >= 0.2 these imports live in langchain_community
# (langchain_community.embeddings, .vectorstores, .llms).
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.memory import ConversationBufferMemory
from langchain.llms import HuggingFacePipeline

# Model Configuration
MODEL_CONFIG = {
"phi-3-mini": {
"name": "microsoft/phi-3-mini-128k-instruct",
"max_tokens": 1024,
"temperature": 0.8
},
"Mistral-7B": {
"name": "mistralai/Mistral-7B-Instruct-v0.3",
"max_tokens": 512,
"temperature": 0.7
}
}
# Cache Stores
vector_store_cache = {}
model_pipeline_cache = {}
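# The embedder must match the model used to build the FAISS indexes under
# vector_stores/ (HuggingFaceEmbeddings defaults to
# sentence-transformers/all-mpnet-base-v2).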
embedder = HuggingFaceEmbeddings()

def load_vector_store(store_name):
    """Cache vector stores in memory (one FAISS index per knowledge base)."""
    if store_name not in vector_store_cache:
        # On LangChain >= 0.1, FAISS.load_local also needs
        # allow_dangerous_deserialization=True to unpickle local indexes.
        vector_store_cache[store_name] = FAISS.load_local(
            f"vector_stores/{store_name}",
            embedder
        )
    return vector_store_cache[store_name]

def get_model_pipeline(model_choice):
    """Cache model pipelines in memory"""
    if model_choice not in model_pipeline_cache:
        cfg = MODEL_CONFIG[model_choice]
        # Older transformers releases may additionally need
        # trust_remote_code=True to load Phi-3.
        tokenizer = AutoTokenizer.from_pretrained(cfg["name"])
        model = AutoModelForCausalLM.from_pretrained(
            cfg["name"],
            device_map="auto",
            torch_dtype="auto" if "phi-3" in model_choice else torch.float16
        )
        model_pipeline_cache[model_choice] = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=cfg["max_tokens"],
            do_sample=True,  # required for temperature to take effect
            temperature=cfg["temperature"]
        )
    return model_pipeline_cache[model_choice]
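
# Note: cached pipelines stay resident for the life of the process, so
# loading both models keeps both in memory; plan GPU capacity accordingly.
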
class SessionChain:
"""Per-session chain manager with memory"""
def __init__(self):
self.current_model = None
self.current_vector_store = None
self.chain = None
def get_chain(self, model_choice, vector_store_name):
"""Get or create chain with proper configuration"""
if self.current_model != model_choice or self.current_vector_store != vector_store_name:
self._create_new_chain(model_choice, vector_store_name)
return self.chain
def _create_new_chain(self, model_choice, vector_store_name):
"""Create new chain with updated configuration"""
vector_store = load_vector_store(vector_store_name)
pipe = get_model_pipeline(model_choice)
        self.chain = ConversationalRetrievalChain.from_llm(
            llm=HuggingFacePipeline(pipeline=pipe),
            retriever=vector_store.as_retriever(),
            # The chain expects history under "chat_history", not "history".
            memory=ConversationBufferMemory(
                memory_key="chat_history",
                output_key="answer",
                return_messages=True
            ),
            verbose=False
        )
self.current_model = model_choice
self.current_vector_store = vector_store_name
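
# Note: switching the model or knowledge base rebuilds the chain with fresh
# memory; respond() then re-seeds it from the visible Gradio history.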

def respond(message, history, model_choice, vector_store, session_state):
    """Handle message with cached resources and session chain"""
    # Initialize session chain if not exists
    if session_state is None:
        session_state = SessionChain()
    # Get the appropriate chain for this session
    chain = session_state.get_chain(model_choice, vector_store)
    try:
        # Rebuild memory from the visible Gradio history; clearing first
        # prevents the same exchanges being re-saved on every turn.
        chain.memory.clear()
        for human, ai in history[-5:]:  # keep the last 5 exchanges as memory
            chain.memory.save_context({"input": human}, {"output": ai})
        # Generate response
        result = chain.invoke({"question": message})
        response = result["answer"]
        return "", history + [(message, response)], session_state
    except Exception as e:
        return "", history + [(message, f"⚠️ Error: {str(e)}")], session_state

with gr.Blocks() as demo:
gr.Markdown("# 🚀 Optimized Chat with Session Management")
# UI Components
model_dropdown = gr.Dropdown(
list(MODEL_CONFIG.keys()),
value="phi-3-mini",
label="Select Model"
)
vector_store_dropdown = gr.Dropdown(
["legal_docs", "tech_docs"],
value="tech_docs",
label="Knowledge Base"
)
    # Per-session state: gr.State keeps one SessionChain per browser session
    # (held server-side, keyed to the session)
session = gr.State()
chatbot = gr.Chatbot(height=400)
msg = gr.Textbox(label="Your Message")
clear = gr.Button("Clear History")
# Chat handlers
msg.submit(
respond,
[msg, chatbot, model_dropdown, vector_store_dropdown, session],
[msg, chatbot, session]
)
clear.click(
lambda: ([], None),
[],
[chatbot, session]
)
demo.launch()
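
# The app assumes FAISS indexes already exist under vector_stores/legal_docs
# and vector_stores/tech_docs. A minimal offline build sketch (assuming a
# list of already-chunked LangChain Documents named `docs`) might look like:
#
#   from langchain.vectorstores import FAISS
#   from langchain.embeddings import HuggingFaceEmbeddings
#   FAISS.from_documents(docs, HuggingFaceEmbeddings()).save_local(
#       "vector_stores/tech_docs")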