Spaces:
Running
Running
| import os | |
| import torch | |
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
| from threading import Thread | |
# Scratch directory handed to `from_pretrained(offload_folder=...)` below.
# It is created eagerly so the path exists before the model is loaded.
OFFLOAD_DIR = "offload_dir"
if not os.path.isdir(OFFLOAD_DIR):
    os.makedirs(OFFLOAD_DIR, exist_ok=True)
MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"

print("π Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

print("π Loading BitNet 1.58-bit model safely into CPU RAM...")
# Map the entire model onto the CPU. In a device map, the empty-string key ""
# refers to the root module, i.e. every submodule; a key such as " " does not
# name any module, so it would not place the weights anywhere.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map={"": "cpu"},
    low_cpu_mem_usage=True,
    # Only consulted for modules mapped to "disk"; kept as a harmless fallback.
    offload_folder=OFFLOAD_DIR,
)
def _history_to_messages(history):
    """Return Gradio chat ``history`` as a list of ``{"role", "content"}`` dicts.

    Accepts both history shapes Gradio's ChatInterface can pass: ``[user, assistant]``
    pairs ("tuples" format) and role/content dicts ("messages" format). Turns whose
    content is not a string (e.g. ``None`` for a missing reply) are skipped, because
    the chat template expects text.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            pairs = [(turn.get("role"), turn.get("content"))]
        else:
            user_prompt, bot_response = turn
            pairs = [("user", user_prompt), ("assistant", bot_response)]
        for role, content in pairs:
            if role and isinstance(content, str):
                messages.append({"role": role, "content": content})
    return messages


def chat_generation(message, history, max_new_tokens, temperature, top_p):
    """Stream a model reply to ``message`` given the prior chat ``history``.

    Args:
        message: The new user message.
        history: Prior turns from Gradio, as ``[user, assistant]`` pairs or
            role/content message dicts.
        max_new_tokens: Upper bound on generated tokens (coerced to ``int``).
        temperature: Sampling temperature; ``0`` (or less) selects greedy decoding.
        top_p: Nucleus-sampling threshold; only used when sampling.

    Yields:
        The reply text accumulated so far, after each streamed chunk.
    """
    conversation = _history_to_messages(history)
    conversation.append({"role": "user", "content": message})

    # Tokenize straight through the tokenizer's chat template. Rendering to a string
    # and re-tokenizing with tokenizer(...) can add special tokens (e.g. BOS) a second
    # time when the template already emits them.
    inputs = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)

    do_sample = temperature > 0.0
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        pad_token_id=tokenizer.eos_token_id,
    )
    if do_sample:
        # Sampling knobs are ignored by greedy decoding (and make generate() warn),
        # so pass them only when actually sampling.
        generate_kwargs.update(temperature=temperature, top_p=top_p)

    # Run generation in a background thread so chunks can be yielded to the UI
    # while tokens are still being produced.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
    # The streamer is exhausted once generation ends; reap the worker thread.
    thread.join()
# --- Gradio UI setup ---
# Slider defaults, shared with the example rows below so the two stay in sync.
DEFAULT_MAX_NEW_TOKENS = 256
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_P = 0.9

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# 🤖 Microsoft BitNet b1.58 (2B-4T) Chatbot\n"
        "Running live on **1.58-bit ternary precision** quantization layers! "
        "Optimized for extreme memory efficiency on CPU.\n"
    )
    with gr.Accordion("⚙️ Generation Settings", open=False):
        max_tokens = gr.Slider(minimum=1, maximum=1024, value=DEFAULT_MAX_NEW_TOKENS, step=1, label="Max New Tokens")
        temp = gr.Slider(minimum=0.0, maximum=1.0, value=DEFAULT_TEMPERATURE, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=DEFAULT_TOP_P, step=0.05, label="Top-p")

    # When additional_inputs are present, each example row must be
    # [message, *values for each additional input]; a message-only row can leave
    # the sliders without values when the example is clicked.
    gr.ChatInterface(
        fn=chat_generation,
        additional_inputs=[max_tokens, temp, top_p],
        examples=[
            [prompt, DEFAULT_MAX_NEW_TOKENS, DEFAULT_TEMPERATURE, DEFAULT_TOP_P]
            for prompt in (
                "Explain the concept of 1.58-bit LLMs like I am 5 years old.",
                "Write a Python script to sort a list using quicksort.",
            )
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()