"""Gradio chat UI for the Zenithex assistant backed by a 4-bit quantized causal LM."""

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# -----------------------------
# Model & Tokenizer Setup
# -----------------------------
# NOTE(review): this checkpoint name is a GPTQ model, but a BitsAndBytesConfig is
# passed below. GPTQ checkpoints are already quantized and transformers will
# reject (or ignore) a bnb 4-bit config for them — confirm whether the intended
# checkpoint is the unquantized Mistral-7B-Instruct-v0.2 or a local Zenithex path.
model_name = "TheBloke/Mistral-7B-Instruct-v0.2-GPTQ"  # replace with your Zenithex checkpoint path if local

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)

print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",          # accelerate decides placement; may shard across devices
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)


# -----------------------------
# Inference Function
# -----------------------------
def chat_with_zenithex(user_input, history):
    """Generate one assistant reply given the new user message and chat history.

    Args:
        user_input: The user's latest message (str).
        history: List of (user, assistant) tuples from earlier turns, or None.

    Returns:
        (reply, history): the assistant's reply string and the updated history
        with the new (user_input, reply) pair appended.
    """
    history = history or []

    # Rebuild the running conversation as a flat prompt with history.
    turns = [f"User: {u}\nAssistant: {a}\n" for u, a in history]
    turns.append(f"User: {user_input}\nAssistant:")
    conversation = "".join(turns)

    # Use the model's own device instead of hard-coding "cuda": with
    # device_map="auto" the first weights may live on any device, and this
    # also keeps the script runnable on CPU-only hosts.
    inputs = tokenizer(conversation, return_tensors="pt").to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
    )

    # Decode only the newly generated tokens (everything past the prompt).
    # This is robust even when a user message itself contains "Assistant:",
    # which would break a string-split on the full decoded text.
    prompt_len = inputs["input_ids"].shape[1]
    reply = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

    history.append((user_input, reply))
    return reply, history


# -----------------------------
# Gradio Interface
# -----------------------------
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# 🚀 Zenithex AI")
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(label="Type your message:")
    clear = gr.Button("Clear Chat")
    state = gr.State([])

    def user_submit(message, history):
        """Handle a textbox submit: generate a reply, clear the box, refresh the chat."""
        reply, history = chat_with_zenithex(message, history)
        return "", history, history

    msg.submit(user_submit, [msg, state], [msg, chatbot, state])
    # Reset both the visible chat window and the server-side history state.
    clear.click(lambda: ([], []), None, [chatbot, state])


if __name__ == "__main__":
    # Guard the launch so importing this module (e.g. for tests) does not
    # start the web server; running it as a script behaves as before.
    demo.launch(server_name="0.0.0.0", server_port=7860)