import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.memory import ConversationBufferMemory
from langchain.llms import HuggingFacePipeline
import torch
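
# NOTE (version assumption): the imports above follow the pre-0.1 LangChain
# layout. On LangChain >= 0.1, HuggingFaceEmbeddings, FAISS, and
# HuggingFacePipeline moved to langchain_community
# (e.g. from langchain_community.vectorstores import FAISS); adjust to match
# your installed version.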
# Model configuration
MODEL_CONFIG = {
    "phi-3-mini": {
        "name": "microsoft/Phi-3-mini-128k-instruct",
        "max_tokens": 1024,
        "temperature": 0.8
    },
    "Mistral-7B": {
        "name": "mistralai/Mistral-7B-Instruct-v0.3",
        "max_tokens": 512,
        "temperature": 0.7
    }
}
# Cache stores
vector_store_cache = {}
model_pipeline_cache = {}
embedder = HuggingFaceEmbeddings()
def load_vector_store(store_name):
    """Cache vector stores in memory."""
    if store_name not in vector_store_cache:
        vector_store_cache[store_name] = FAISS.load_local(
            f"vector_stores/{store_name}",
            embedder,
            # Required on recent LangChain releases because FAISS stores are
            # pickled; only enable this for indexes you created yourself, and
            # drop the kwarg on older releases that do not accept it.
            allow_dangerous_deserialization=True
        )
    return vector_store_cache[store_name]
def get_model_pipeline(model_choice):
    """Cache model pipelines in memory."""
    if model_choice not in model_pipeline_cache:
        cfg = MODEL_CONFIG[model_choice]
        tokenizer = AutoTokenizer.from_pretrained(cfg["name"])
        model = AutoModelForCausalLM.from_pretrained(
            cfg["name"],
            device_map="auto",
            torch_dtype="auto" if "phi-3" in model_choice else torch.float16
        )
        model_pipeline_cache[model_choice] = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=cfg["max_tokens"],
            # temperature only takes effect when sampling is enabled
            do_sample=True,
            temperature=cfg["temperature"]
        )
    return model_pipeline_cache[model_choice]
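
# The two dicts above are module-level, so tokenizers, models, and FAISS
# indexes are loaded once per process and shared by every Gradio session;
# only the chain (and its conversation memory) below is per-session.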
class SessionChain:
    """Per-session chain manager with memory."""

    def __init__(self):
        self.current_model = None
        self.current_vector_store = None
        self.chain = None

    def get_chain(self, model_choice, vector_store_name):
        """Get or create the chain for the requested configuration."""
        if (self.current_model != model_choice
                or self.current_vector_store != vector_store_name):
            self._create_new_chain(model_choice, vector_store_name)
        return self.chain

    def _create_new_chain(self, model_choice, vector_store_name):
        """Create a new chain with the updated configuration."""
        vector_store = load_vector_store(vector_store_name)
        pipe = get_model_pipeline(model_choice)
        self.chain = ConversationalRetrievalChain.from_llm(
            llm=HuggingFacePipeline(pipeline=pipe),
            retriever=vector_store.as_retriever(),
            # ConversationalRetrievalChain expects the history under the
            # "chat_history" key; the default memory key ("history") breaks it
            memory=ConversationBufferMemory(
                memory_key="chat_history",
                return_messages=True
            ),
            verbose=False
        )
        self.current_model = model_choice
        self.current_vector_store = vector_store_name
def respond(message, history, model_choice, vector_store, session_state):
    """Handle a message with cached resources and the session chain."""
    # Initialize the session chain if it does not exist yet
    if session_state is None:
        session_state = SessionChain()

    # Get the appropriate chain for this session
    chain = session_state.get_chain(model_choice, vector_store)

    try:
        # Rebuild memory from the Gradio history; clear first so replaying
        # the last 5 exchanges does not duplicate previously saved turns
        chain.memory.clear()
        for human, ai in history[-5:]:
            chain.memory.save_context({"input": human}, {"output": ai})

        # Generate response
        result = chain.invoke({"question": message})
        response = result["answer"]
        return "", history + [(message, response)], session_state
    except Exception as e:
        return "", history + [(message, f"⚠️ Error: {e}")], session_state
with gr.Blocks() as demo:
    gr.Markdown("# 🚀 Optimized Chat with Session Management")

    # UI components
    model_dropdown = gr.Dropdown(
        list(MODEL_CONFIG.keys()),
        value="phi-3-mini",
        label="Select Model"
    )
    vector_store_dropdown = gr.Dropdown(
        ["legal_docs", "tech_docs"],
        value="tech_docs",
        label="Knowledge Base"
    )
    # Per-session state, held server-side (one SessionChain per browser session)
    session = gr.State()
    chatbot = gr.Chatbot(height=400)
    msg = gr.Textbox(label="Your Message")
    clear = gr.Button("Clear History")

    # Chat handlers
    msg.submit(
        respond,
        [msg, chatbot, model_dropdown, vector_store_dropdown, session],
        [msg, chatbot, session]
    )
    clear.click(
        lambda: ([], None),
        [],
        [chatbot, session]
    )

demo.launch()
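
# ------------------------------------------------------------------
# Appendix (sketch): one way to build a FAISS index this app can load.
# This is a separate, hypothetical helper script (e.g. build_stores.py),
# not part of the app above; the docs/tech_docs.txt source path and the
# chunking parameters are assumptions chosen for illustration.
# ------------------------------------------------------------------
from langchain.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

docs = TextLoader("docs/tech_docs.txt").load()  # assumed source file
chunks = RecursiveCharacterTextSplitter(
    chunk_size=500, chunk_overlap=50
).split_documents(docs)
# Use the same embedding model as the app so queries match the index
FAISS.from_documents(chunks, HuggingFaceEmbeddings()).save_local(
    "vector_stores/tech_docs"  # matches the path in load_vector_store()
)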