# Hugging Face Space scrape residue (not code): author "nimakhankh",
# commit 5ed489a, commit message "Performance".
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import gc
from dotenv import load_dotenv
from transformers import BitsAndBytesConfig
# Load LOCAL_* settings from a .env file, if one is present.
load_dotenv()
# Optional local checkout of the base model; None when the env var is unset.
LOCAL_BASE_MODEL_PATH = os.getenv("LOCAL_BASE_MODEL_PATH")
# Directory produced by merging the fine-tuned adapter into the base weights.
LOCAL_MERGED_PATH = "./scrum_test_merged"
# Presence of the merged directory is used as the "running locally" signal.
IS_LOCAL = os.path.exists(LOCAL_MERGED_PATH)
# Dropdown label -> model id/path. Locally we prefer on-disk paths; otherwise
# we fall back to the Hugging Face Hub repository ids.
MODEL_OPTIONS = {
"Standard Qwen2.5-1.5B-Instruct": LOCAL_BASE_MODEL_PATH if IS_LOCAL and LOCAL_BASE_MODEL_PATH else "Qwen/Qwen2.5-1.5B-Instruct",
"Scrum Test Model (Merged)": LOCAL_MERGED_PATH if IS_LOCAL else "nimakhankh/qwen-scrum-test-merged",
}
# Module-level state shared between load_model() and chat().
current_model = None
current_tokenizer = None
current_model_id = None
def load_model(model_choice):
    """Load the selected model and tokenizer into the module-level globals.

    Parameters:
        model_choice: one of the keys of MODEL_OPTIONS (from the UI dropdown).

    Returns:
        A human-readable status message for the Gradio status textbox.
    """
    global current_model, current_tokenizer, current_model_id
    model_id = MODEL_OPTIONS[model_choice]
    if current_model_id == model_id:
        return "Model already loaded."
    if current_model is not None:
        # Reset to None instead of `del`: deleting the globals would leave
        # them unbound, and a failure during the load below would then make
        # every later `current_model is None` check raise NameError.
        current_model = None
        current_tokenizer = None
        current_model_id = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Generation needs a pad token; reuse eos when none is defined.
        tokenizer.pad_token = tokenizer.eos_token
    # 4-bit NF4 quantization settings. BitsAndBytesConfig is already imported
    # at the top of the file; the previous duplicate in-function import was
    # redundant and has been removed.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    # Quantize only the base model (per the original intent: "only for the
    # base model"). Keying off the dropdown label is more reliable than the
    # previous `"Qwen" in model_id` substring test, which silently skipped
    # quantization when a local base-model path did not contain "Qwen".
    use_quantization = model_choice == "Standard Qwen2.5-1.5B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quantization_config if use_quantization else None,
        device_map="auto",
        torch_dtype="auto",
        trust_remote_code=True,
    )
    model.eval()
    current_model = model
    current_tokenizer = tokenizer
    current_model_id = model_id
    return f"{model_choice} loaded successfully."
def chat(message, history):
    """Generate a reply to `message` using the currently loaded model.

    Parameters:
        message: the latest user message (str).
        history: prior turns supplied by gr.ChatInterface — either a list of
            [user, assistant] pairs (legacy format) or a list of
            {"role", "content"} dicts (messages format).

    Returns:
        The model's reply as a string, or an instruction to load a model first.
    """
    if current_model is None:
        return "Please select and load a model first."
    # Rebuild the full conversation so the model sees earlier turns; the
    # original implementation dropped `history`, giving the bot no memory.
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            # gradio "messages" format
            messages.append({"role": turn["role"], "content": turn["content"]})
        else:
            # legacy pair format: [user_text, assistant_text]
            user_text, assistant_text = turn
            if user_text:
                messages.append({"role": "user", "content": user_text})
            if assistant_text:
                messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": message})
    input_ids = current_tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(current_model.device)
    # A single untruncated prompt contains no padding, so the mask is all
    # ones. Comparing against pad_token_id was wrong here: pad is aliased to
    # eos, and chat templates legitimately contain the eos token inside the
    # prompt, so real tokens were being masked out.
    attention_mask = torch.ones_like(input_ids)
    # Pass sampling parameters only when sampling is enabled; transformers
    # warns when temperature/top_p are supplied with do_sample=False.
    generation_kwargs = {
        "max_new_tokens": 512,
        "repetition_penalty": 1.1,
    }
    if IS_LOCAL:
        generation_kwargs["do_sample"] = False  # deterministic local runs
    else:
        generation_kwargs.update(do_sample=True, temperature=0.7, top_p=0.9)
    with torch.inference_mode():  # no autograd bookkeeping during generation
        outputs = current_model.generate(
            input_ids,
            attention_mask=attention_mask,
            **generation_kwargs,
        )
    # Decode only the newly generated tokens (skip the prompt prefix).
    response = current_tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
    return response
# Assemble the UI: a model selector row, a load-status box, and the chat pane.
with gr.Blocks(title="Nima Chatbot - Scrum Expert") as demo:
    gr.Markdown("# Chatbot with Model Selection")
    model_names = list(MODEL_OPTIONS.keys())
    with gr.Row():
        selector = gr.Dropdown(
            choices=model_names,
            value=model_names[0],
            label="Select Model",
        )
        load_button = gr.Button("Load Model")
    load_status = gr.Textbox(label="Model Load Status", interactive=False)
    # Clicking the button swaps the active model and reports the outcome.
    load_button.click(load_model, inputs=selector, outputs=load_status)
    gr.ChatInterface(
        fn=chat,
        title="Chat with Selected Model",
        description="Select and load a model, then start chatting.",
    )
demo.launch()