# Llm / app.py
# Dinnerbone5443 — "Create app.py" (commit 231c74d, verified)
import os
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
# Create an offload folder. Every module is mapped to the CPU below, so this
# folder exists only to satisfy Accelerate/Transformers code paths that want an
# `offload_folder` to be configured.
OFFLOAD_DIR = "offload_dir"
os.makedirs(OFFLOAD_DIR, exist_ok=True)

MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"

print("🔄 Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

print("🔄 Loading BitNet 1.58-bit model safely into CPU RAM...")
# Force the whole model onto the CPU. In a `device_map` dict the empty-string
# key "" names the root module (i.e. the entire model); a key such as " " is
# not a module name and does not cover the model's layers.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map={"": "cpu"},
    low_cpu_mem_usage=True,
    offload_folder=OFFLOAD_DIR,
)
def chat_generation(message, history, max_new_tokens, temperature, top_p):
"""
Handles streaming chat tokens for a responsive UI.
"""
conversation = []
for user_prompt, bot_response in history:
conversation.append({"role": "user", "content": user_prompt})
conversation.append({"role": "assistant", "content": bot_response})
conversation.append({"role": "user", "content": message})
# Format the prompt using Llama-3 style templates
prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
inputs,
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=True if temperature > 0.0 else False,
temperature=temperature,
top_p=top_p,
pad_token_id=tokenizer.eos_token_id
)
# Run text generation in a background thread to prevent UI freezing
thread = Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()
partial_text = ""
for new_text in streamer:
partial_text += new_text
yield partial_text
# --- Gradio UI setup ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# 🤖 Microsoft BitNet b1.58 (2B-4T) Chatbot
Running live on **1.58-bit ternary precision** quantization layers! Optimized for extreme memory efficiency on CPU.
"""
    )
    # The sliders are rendered inside this accordion. ChatInterface reuses the
    # already-rendered components as its additional inputs, in the order that
    # chat_generation expects them (max_new_tokens, temperature, top_p).
    with gr.Accordion("⚙️ Generation Settings", open=False):
        max_tokens = gr.Slider(minimum=1, maximum=1024, value=256, step=1, label="Max New Tokens")
        temp = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p")
    gr.ChatInterface(
        fn=chat_generation,
        additional_inputs=[max_tokens, temp, top_p],
        examples=[
            ["Explain the concept of 1.58-bit LLMs like I am 5 years old."],
            ["Write a Python script to sort a list using quicksort."],
        ],
        cache_examples=False,
    )
def _main():
    """Start the Gradio server for the ``demo`` app (script entry point)."""
    demo.launch()


if __name__ == "__main__":
    _main()