# Zenithex_AI / app.py — Gradio chat interface for Zenithex AI
# Uploaded via huggingface_hub by maitidebpratim (commit c170e88, verified).
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# -----------------------------
# Model & Tokenizer Setup
# -----------------------------
# NOTE(review): this is a *GPTQ*-quantized checkpoint, yet bnb_config below
# requests bitsandbytes 4-bit (NF4) loading. transformers treats these as
# competing quantization schemes and will error or ignore one of them —
# confirm which quantization is actually intended.
model_name = "TheBloke/Mistral-7B-Instruct-v0.2-GPTQ" # replace with your Zenithex checkpoint path if local
# bitsandbytes 4-bit loading configuration.
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,   # second quantization pass over the quant constants
bnb_4bit_quant_type="nf4",        # normal-float-4 quantization
bnb_4bit_compute_dtype=torch.bfloat16  # matmuls run in bf16
)
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto",                # let accelerate place layers on available GPUs/CPU
quantization_config=bnb_config,
torch_dtype=torch.bfloat16        # NOTE(review): likely redundant with bnb_4bit_compute_dtype
)
# -----------------------------
# Inference Function
# -----------------------------
def chat_with_zenithex(user_input, history):
    """Generate one assistant reply for *user_input* given the chat history.

    Args:
        user_input: The user's latest message.
        history: List of ``(user, assistant)`` tuples from earlier turns,
            or ``None``/empty for a fresh conversation.

    Returns:
        ``(reply, history)`` — the assistant's reply string and the history
        list with the new ``(user_input, reply)`` turn appended in place.
    """
    history = history or []
    # Rebuild the running transcript in the "User:/Assistant:" prompt format.
    past_turns = [f"User: {u}\nAssistant: {a}\n" for u, a in history]
    conversation = "".join(past_turns) + f"User: {user_input}\nAssistant:"
    # Fix: follow the model's actual placement instead of hard-coding "cuda",
    # so the app also works when device_map put the model (partly) on CPU.
    inputs = tokenizer(conversation, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1
    )
    # Fix: decode only the newly generated tokens. Decoding the whole
    # sequence and splitting on "Assistant:" echoed prompt content and broke
    # whenever a user message itself contained the word "Assistant:".
    prompt_len = inputs["input_ids"].shape[1]
    reply = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
    history.append((user_input, reply))
    return reply, history
# -----------------------------
# Gradio Interface
# -----------------------------
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# πŸš€ Zenithex AI")
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(label="Type your message:")
    clear = gr.Button("Clear Chat")
    state = gr.State([])

    def _on_submit(message, chat_history):
        """Forward one user turn to the model and refresh the UI."""
        reply, chat_history = chat_with_zenithex(message, chat_history)
        # Empty string clears the textbox; history feeds both the visible
        # chat widget and the session-state copy.
        return "", chat_history, chat_history

    def _on_clear():
        """Reset both the visible chat and the stored history."""
        return [], []

    msg.submit(_on_submit, [msg, state], [msg, chatbot, state])
    clear.click(_on_clear, None, [chatbot, state])

# Bind to all interfaces on the standard HF Spaces port.
demo.launch(server_name="0.0.0.0", server_port=7860)