Spaces:
Runtime error
Runtime error
File size: 2,292 Bytes
c170e88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import os

import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# -----------------------------
# Model & Tokenizer Setup
# -----------------------------
# Set the ZENITHEX_MODEL environment variable to load another checkpoint (for
# example a local Zenithex checkpoint path) without editing this file.
model_name = os.environ.get("ZENITHEX_MODEL", "TheBloke/Mistral-7B-Instruct-v0.2-GPTQ")

use_cuda = torch.cuda.is_available()
# bfloat16 needs hardware support; GPUs without it (e.g. T4) fall back to
# float16. Without CUDA the original bfloat16 choice is kept.
compute_dtype = (
    torch.float16
    if use_cuda and not torch.cuda.is_bf16_supported()
    else torch.bfloat16
)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Already-quantized checkpoints (GPTQ, AWQ, ...) store their own
# quantization_config. Passing a BitsAndBytesConfig on top of that conflicts
# with it (depending on the transformers version it is ignored with a warning
# or the load fails), so on-the-fly bitsandbytes 4-bit quantization is only
# requested for unquantized weights, and only when a CUDA GPU is present.
model_config = AutoConfig.from_pretrained(model_name)
is_prequantized = getattr(model_config, "quantization_config", None) is not None

load_kwargs = {"device_map": "auto"}
if not is_prequantized:
    load_kwargs["torch_dtype"] = compute_dtype
    if use_cuda:
        load_kwargs["quantization_config"] = bnb_config
# Pre-quantized checkpoints get no explicit torch_dtype so their quantizer
# picks one (transformers defaults GPTQ/AWQ weights to float16).

print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
# -----------------------------
# Inference Function
# -----------------------------
def _build_prompt(user_input, history):
    """Render prior turns plus the new message into prompt text.

    Returns ``(text, templated)``. When the tokenizer ships a chat template
    it is used, so the model sees the instruction format it was tuned on;
    otherwise a plain "User:/Assistant:" transcript is built.
    """
    if getattr(tokenizer, "chat_template", None):
        messages = []
        for turn in history:
            messages.append({"role": "user", "content": turn[0]})
            messages.append({"role": "assistant", "content": turn[1]})
        messages.append({"role": "user", "content": user_input})
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return text, True
    transcript = "".join(
        f"User: {turn[0]}\nAssistant: {turn[1]}\n" for turn in history
    )
    return transcript + f"User: {user_input}\nAssistant:", False


def chat_with_zenithex(user_input, history):
    """Generate the assistant's reply to ``user_input`` given earlier turns.

    Args:
        user_input: The new user message.
        history: Earlier ``(user, assistant)`` pairs, or ``None``/empty.

    Returns:
        ``(reply, history)`` where ``history`` is a new list holding the
        earlier turns plus ``(user_input, reply)``; the caller's list is not
        modified.
    """
    history = list(history or [])
    prompt, templated = _build_prompt(user_input, history)
    # A rendered chat template already contains the BOS token, so the
    # tokenizer must not add a second one.
    inputs = tokenizer(
        prompt, return_tensors="pt", add_special_tokens=not templated
    ).to(model.device)  # follow wherever device_map placed the model
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        pad_token_id=pad_id,
    )
    # Decode only the newly generated tokens. Splitting the full decoded text
    # on "Assistant:" would cut replies that contain that word themselves.
    prompt_len = inputs["input_ids"].shape[-1]
    reply = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    if not templated:
        # With a plain transcript the model tends to keep writing the
        # dialogue; drop any invented follow-up user turn.
        reply = reply.split("\nUser:")[0]
    reply = reply.strip()
    history.append((user_input, reply))
    return reply, history
# -----------------------------
# Gradio Interface
# -----------------------------
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# 🚀 Zenithex AI")
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(label="Type your message:")
    clear = gr.Button("Clear Chat")
    state = gr.State([])

    def user_submit(message, history):
        """Handle a submitted message.

        Returns values for the outputs ``[msg, chatbot, state]``: an empty
        string to clear the textbox, then the updated history twice (once
        for display, once for the stored state).
        """
        # Skip blank submissions instead of running a full generation on an
        # empty prompt.
        if not message or not message.strip():
            return "", history, history
        _, history = chat_with_zenithex(message, history)
        return "", history, history

    msg.submit(user_submit, [msg, state], [msg, chatbot, state])
    # Reset both the visible conversation and the stored history.
    clear.click(lambda: ([], []), None, [chatbot, state])
demo.launch(server_name="0.0.0.0", server_port=7860)
|