# Zenithex AI — Gradio chat demo.
# NOTE: this Space previously crashed at startup with "Runtime error";
# see the quantization-config fix in the model-setup section below.
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# -----------------------------
# Model & Tokenizer Setup
# -----------------------------
# BUG FIX: the original pointed at a GPTQ checkpoint
# ("TheBloke/Mistral-7B-Instruct-v0.2-GPTQ") while ALSO passing a
# BitsAndBytesConfig. GPTQ and bitsandbytes are mutually exclusive
# quantization schemes — transformers rejects the combination (and the
# GPTQ path additionally requires auto-gptq/optimum), which is the
# likely cause of this Space's startup "Runtime error". Load the
# unquantized base model and let bitsandbytes do the 4-bit quantization.
model_name = "mistralai/Mistral-7B-Instruct-v0.2"  # replace with your Zenithex checkpoint path if local

# 4-bit NF4 quantization with double quantization: fits a 7B model in
# a few GB of VRAM while computing in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)

print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",  # place/shard the model automatically across available devices
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)
# -----------------------------
# Inference Function
# -----------------------------
def chat_with_zenithex(user_input, history):
    """Generate one assistant reply for the new user message.

    Parameters:
        user_input: the user's latest message (str).
        history: list of (user, assistant) tuples from earlier turns, or None.

    Returns:
        (reply, history): the generated reply string and the history list
        with the new (user_input, reply) turn appended.
    """
    history = history or []

    # Rebuild the whole conversation as a plain-text prompt.
    turns = [f"User: {u}\nAssistant: {a}\n" for u, a in history]
    conversation = "".join(turns) + f"User: {user_input}\nAssistant:"

    # BUG FIX: the original did .to("cuda") unconditionally, which crashes
    # on CPU-only hosts. Move the inputs to wherever the model actually
    # lives (consistent with device_map="auto").
    inputs = tokenizer(conversation, return_tensors="pt").to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
    )

    # BUG FIX: decode only the newly generated tokens rather than splitting
    # the full decoded text on "Assistant:" — the old approach returned the
    # wrong span whenever the user's own message contained "Assistant:".
    prompt_len = inputs["input_ids"].shape[-1]
    reply = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

    history.append((user_input, reply))
    return reply, history
# -----------------------------
# Gradio Interface
# -----------------------------
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# π Zenithex AI")
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(label="Type your message:")
    clear = gr.Button("Clear Chat")
    state = gr.State([])

    def _on_submit(message, chat_state):
        # Delegate generation, then clear the textbox and push the updated
        # history to both the visible chat widget and the stored state.
        _, chat_state = chat_with_zenithex(message, chat_state)
        return "", chat_state, chat_state

    msg.submit(_on_submit, [msg, state], [msg, chatbot, state])
    clear.click(lambda: ([], []), None, [chatbot, state])

demo.launch(server_name="0.0.0.0", server_port=7860)