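# Streamlit chat interface for a LoRA fine-tuned model loaded with PEFT and run
# entirely on the CPU. Each response is generated in a background thread and
# streamed to the page through a queue-backed TextStreamer.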
import streamlit as st
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, TextStreamer
# bitsandbytes is no longer needed
import io
import sys
import threading
import time
import queue  # Import the queue module

# --- Configuration ---
DEFAULT_MODEL_PATH = "lora_model"  # Or your default path
# DEFAULT_LOAD_IN_4BIT is removed as we are not using quantization

# --- Page Configuration ---
st.set_page_config(page_title="Fine-tuned LLM Chat Interface (CPU)", layout="wide")
st.title("Fine-tuned LLM Chat Interface (CPU Mode)")
st.warning("Running in CPU mode. Expect slower generation times and higher RAM usage.", icon="⚠️")

# --- Model Loading (Cached for CPU) ---
@st.cache_resource  # Cache the loaded model so it is not reloaded on every rerun
def load_model_and_tokenizer_cpu(model_path):
    """Loads the PEFT model and tokenizer onto the CPU."""
    try:
        # Use standard float32 for CPU compatibility and stability
        torch_dtype = torch.float32
        model = AutoPeftModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
            # load_in_4bit is not used; the adapter is loaded in full precision
            device_map="cpu",  # Force loading onto CPU
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model.eval()  # Set model to evaluation mode
        print("Model and tokenizer loaded successfully onto CPU.")
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model from path '{model_path}' onto CPU: {e}", icon="🚨")
        print(f"Error loading model onto CPU: {e}")
        return None, None

# --- Custom Streamer Class (Modified for Queue) ---
class QueueStreamer(TextStreamer):
    def __init__(self, tokenizer, skip_prompt, q):
        super().__init__(tokenizer, skip_prompt=skip_prompt)
        self.queue = q
        self.stop_signal = None  # Sentinel placed on the queue to mark the end of generation

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Puts finalized text onto the queue; signals the end once the stream closes."""
        self.queue.put(text)
        if stream_end:
            self.queue.put(self.stop_signal)

    def end(self):
        """Flushes any text still held by TextStreamer, then signals the end of generation."""
        # The parent implementation decodes the remaining token cache and calls
        # on_finalized_text(..., stream_end=True), which enqueues the sentinel above.
        # Overriding end() without this flush would drop the final words of each response.
        super().end()

# --- Sidebar for Settings ---
with st.sidebar:
    st.header("Model Configuration")
    st.info(f"Model loaded on startup: `{DEFAULT_MODEL_PATH}` (CPU Mode).")
    st.header("Generation Settings")
    temperature = st.slider("Temperature", min_value=0.0, max_value=2.0, value=0.7, step=0.05)
    # min_p is less common than top_p/top_k; it is kept here to allow experimentation,
    # but consider a top_p or top_k slider instead, e.g.:
    # top_p = st.slider("Top P", min_value=0.01, max_value=1.0, value=0.9, step=0.01)
    min_p = st.slider("Min P", min_value=0.01, max_value=1.0, value=0.1, step=0.01)
    max_tokens = st.slider("Max New Tokens", min_value=50, max_value=2048, value=256, step=50)  # Reduced default for CPU
    if st.button("Clear Chat History"):
        st.session_state.messages = []
        st.rerun()  # Rerun to clear the display immediately

# --- Load Model (runs only once on first load, or again if the cache is cleared) ---
model, tokenizer = load_model_and_tokenizer_cpu(DEFAULT_MODEL_PATH)

# --- Initialize Session State ---
if "messages" not in st.session_state:
    st.session_state.messages = []

# --- Main Chat Interface ---
if model is None or tokenizer is None:
    st.error("CPU model loading failed. Please check the path, available RAM, and logs. Cannot proceed.")
    st.stop()

# Display conversation history
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Handle user input
user_input = st.chat_input("Ask the fine-tuned model (CPU)...")

if user_input:
    # Add user message to history and display it
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)

    # Prepare for model response
    with st.chat_message("assistant"):
        response_placeholder = st.empty()
        response_placeholder.markdown("Generating response on CPU... please wait... ▌")  # Initial message

        text_queue = queue.Queue()  # Queue for this specific response
        # Initialize the modified streamer
        text_streamer = QueueStreamer(tokenizer, skip_prompt=True, q=text_queue)

        # Prepare input for the model
        messages_for_model = st.session_state.messages
        generated_text = ""  # Initialized here so the error handler below can always reference it

        try:
            # Ensure inputs are on the CPU (model.device should be 'cpu' now)
            target_device = model.device
            # print(f"Model device: {target_device}")  # Debugging: should print 'cpu'

            if tokenizer.chat_template:
                inputs = tokenizer.apply_chat_template(
                    messages_for_model,
                    tokenize=True,
                    add_generation_prompt=True,
                    return_tensors="pt",
                ).to(target_device)  # Send input tensors to CPU
            else:
                prompt_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages_for_model]) + "\nassistant:"
                inputs = tokenizer(prompt_text, return_tensors="pt").input_ids.to(target_device)  # Send input tensors to CPU

            # Generation arguments
            generation_kwargs = dict(
                input_ids=inputs,
                streamer=text_streamer,  # Use the QueueStreamer
                max_new_tokens=max_tokens,
                use_cache=True,  # Caching can still help CPU generation speed
                temperature=temperature if temperature > 0 else None,
                top_p=None,  # Consider adding a top_p slider in the UI
                # top_k=50,  # Example: or use top_k
                min_p=min_p,
                do_sample=temperature > 0,
                eos_token_id=tokenizer.eos_token_id,
                # Compare against None explicitly: a pad_token_id of 0 is valid but falsy
                pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
            )

            # Define the target function for the thread
            def generation_thread_func():
                try:
                    # Run generation in the background thread (on CPU).
                    # Wrap in torch.no_grad() to save memory during inference.
                    with torch.no_grad():
                        model.generate(**generation_kwargs)
                except Exception as e:
                    # Streamlit UI calls are not reliable from a background thread,
                    # so only log here; the main thread reports any failure to the UI.
                    print(f"Error in generation thread: {e}")
                finally:
                    # Ensure the queue loop terminates even if an error occurred
                    text_streamer.end()

            # Start the generation thread
            thread = threading.Thread(target=generation_thread_func)
            thread.start()

            # --- Main thread: read from the queue and update the UI ---
            while True:
                try:
                    # Get the next text chunk from the queue.
                    # Use a timeout so a hung or dead thread cannot block forever.
                    chunk = text_queue.get(block=True, timeout=1)  # Short timeout is fine for slow CPU generation
                    if chunk is text_streamer.stop_signal:  # End signal (None)
                        break
                    generated_text += chunk
                    response_placeholder.markdown(generated_text + "▌")  # Update placeholder
                except queue.Empty:
                    # Queue is empty: if the generation thread has exited without
                    # posting the stop signal (e.g. due to an error), stop waiting.
                    if not thread.is_alive():
                        break
                    continue
                except Exception as e:
                    st.error(f"Error reading from generation queue: {e}")
                    print(f"Error reading from queue: {e}")
                    break

            # Final update without the cursor
            response_placeholder.markdown(generated_text)

            # Add the complete assistant response to history *after* generation
            if generated_text:
                st.session_state.messages.append({"role": "assistant", "content": generated_text})
            else:
                # Generation failed silently in the thread or produced nothing
                if not any(m["role"] == "assistant" and m["content"].startswith("*Error") for m in st.session_state.messages):
                    st.warning("Assistant produced no output.", icon="⚠️")

            # Wait briefly for the thread to finish if it hasn't already
            thread.join(timeout=5.0)  # A longer timeout may be needed if cleanup is slow
        except Exception as e:
            st.error(f"Error during generation setup or queue handling: {e}", icon="🔥")
            print(f"Error setting up generation or handling queue: {e}")
            # Record the error in the chat history for context
            error_message = f"*Error generating response: {e}*"
            if not generated_text:  # Nothing was generated at all
                st.session_state.messages.append({"role": "assistant", "content": error_message})
                response_placeholder.error(f"Error generating response: {e}")
            else:  # Some text was generated before the error; append the notice to it
                st.session_state.messages.append({"role": "assistant", "content": generated_text + "\n\n" + error_message})
                response_placeholder.markdown(generated_text + "\n\n" + error_message)
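
# A typical way to try this locally (assuming the script is saved as app.py and the
# fine-tuned LoRA adapter directory sits next to it as ./lora_model):
#   pip install streamlit torch transformers peft
#   streamlit run app.py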