# QnA-bitnet-Lora / app.py
# ogflash's picture
# Update app.py
# 3bc67eb verified
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import gradio as gr
# Merged 4-bit BitNet Q&A model hosted on the Hugging Face Hub.
model_id = "ogflash/merged-mistral-4bit-bitnetQnA"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# 4-bit NF4 quantization; compute in fp16 to keep memory low on small GPUs.
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16
)
# device_map="auto" lets accelerate place layers on the available device(s).
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
quantization_config=bnb_config,
torch_dtype=torch.float16
)
# Inference only — disable dropout etc.
model.eval()
def respond(message, history):
    """Generate a reply for *message* given prior (user, bot) turn pairs.

    Builds an Alpaca-style instruction prompt from the chat history, samples
    up to 300 new tokens, and returns only the newly generated text.

    Args:
        message: The user's latest question (str).
        history: List of (user, bot) string tuples from earlier turns.

    Returns:
        The model's response as a stripped string.
    """
    prompt = ""
    for user, bot in history:
        prompt += f"### Instruction:\n{user}\n\n### Response:\n{bot}\n\n"
    prompt += f"### Instruction:\n{message}\n\n### Response:\n"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # inference_mode: no autograd bookkeeping during generation (saves memory).
    with torch.inference_mode():
        outputs = model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.7)
    # Decode only the newly generated tokens. This is robust even if the
    # user's message itself contains the literal text "### Response:",
    # which would break a naive split of the full decoded output.
    prompt_len = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return response.strip()
# Gradio UI: a simple chat interface backed by respond().
with gr.Blocks() as demo:
    gr.Markdown("# BitNet Q&A Chatbot")
    chatbot = gr.Chatbot()
    history = gr.State([])
    msg = gr.Textbox(placeholder="Ask about 1-bit LLMs or BitNet...")
    clear = gr.Button("Clear")

    def user_submit(user_message, chat_history):
        """Run the model on the new message and append the turn to history."""
        bot_reply = respond(user_message, chat_history)
        chat_history.append((user_message, bot_reply))
        # Clear the textbox; mirror updated history into both chatbot and state.
        return "", chat_history, chat_history

    def reset_chat():
        """Wipe both the visible chat and the stored history."""
        return [], []

    msg.submit(user_submit, [msg, history], [msg, chatbot, history])
    clear.click(reset_chat, outputs=[chatbot, history])

demo.launch()