import streamlit as st
from transformers import pipeline
import torch


class LlamaDemo:
    def __init__(self):
        self.model_name = "meta-llama/Llama-2-70b-chat-hf"
        # Lazily initialize the pipeline so the model is only loaded on first use
        self._pipe = None

    @property
    def pipe(self):
        if self._pipe is None:
            self._pipe = pipeline(
                "text-generation",
                model=self.model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True
            )
        return self._pipe
    def generate_response(self, prompt: str, max_new_tokens: int = 512) -> str:
        # Wrap the prompt in Llama 2's chat instruction tags
        formatted_prompt = f"[INST] {prompt} [/INST]"

        # Generate a completion with the lazily loaded pipeline
        response = self.pipe(
            formatted_prompt,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            top_p=0.9
        )[0]["generated_text"]

        # The pipeline returns the prompt plus the completion; keep only the
        # text after the closing instruction tag
        return response.split("[/INST]")[-1].strip()
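
# Hedged usage sketch (not part of the original app): assuming the gated
# meta-llama/Llama-2-70b-chat-hf weights are accessible with an accepted
# license and a logged-in Hugging Face token, and that enough GPU memory is
# available, the class can be exercised outside Streamlit like this:
#
#     demo = LlamaDemo()
#     print(demo.generate_response("Summarize mixed precision in one sentence."))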

def main():
    st.set_page_config(
        page_title="Llama 2 Chat Demo",
        page_icon="🦙",
        layout="wide"
    )
    st.title("🦙 Llama 2 Chat Demo")

    # Create the model wrapper once per session
    if 'llama' not in st.session_state:
        with st.spinner("Loading Llama 2... This might take a few minutes..."):
            st.session_state.llama = LlamaDemo()

    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    # Chat interface: replay the conversation so far, then handle new input
    with st.container():
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.write(message["content"])

        if prompt := st.chat_input("What would you like to discuss?"):
            st.session_state.chat_history.append({
                "role": "user",
                "content": prompt
            })
            with st.chat_message("user"):
                st.write(prompt)

            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    try:
                        response = st.session_state.llama.generate_response(prompt)
                        st.write(response)
                        st.session_state.chat_history.append({
                            "role": "assistant",
                            "content": response
                        })
                    except Exception as e:
                        st.error(f"Error: {str(e)}")
    with st.sidebar:
        st.markdown("""
        ### About
        This demo uses Llama-2-70B-chat, a large language model from Meta.
        The model runs with automatic device mapping and mixed precision
        for optimal performance.
        """)

        if st.button("Clear Chat History"):
            st.session_state.chat_history = []
            st.rerun()  # st.experimental_rerun() is deprecated in recent Streamlit releases

if __name__ == "__main__":
    main()
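
# Hedged run notes (assumed filename and package list, not from the original
# app): with the file saved as app.py, the demo could be launched with
#     pip install streamlit transformers torch accelerate
#     streamlit run app.py
# accelerate is required for device_map="auto" to place the model shards
# across the available devices.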