"""Gradio chat UI that streams replies from Llama-3.1-8B-Instruct."""

import gc
import sys
import threading
import time

import gradio as gr
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
SYSTEM_MESSAGE = "You are a helpful and friendly AI assistant."
# Upper bound on prompt + generated tokens (the original script capped max_length at 2048).
MAX_CONTEXT_TOKENS = 2048

try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype=torch.float16, device_map="auto"
    )
    device = model.device  # Get device automatically
    print(f"Model loaded on {device}")
except Exception as e:
    print(f"Error loading model: {e}", file=sys.stderr)
    sys.exit(1)


def clean_memory():
    """Run the garbage collector and, on CUDA, release cached allocator blocks.

    Called once after each generation rather than from a 1-second polling
    thread: emptying the CUDA cache while generate() is running only hands back
    memory the allocator must immediately request again.
    """
    gc.collect()
    if device.type == "cuda":  # Check device type explicitly
        torch.cuda.empty_cache()


def _chat_messages(message, history):
    """Convert ``(speaker, text)`` history plus *message* into chat-template messages.

    A speaker of ``"User"`` maps to the ``user`` role; every other speaker
    (``"AI"`` in this app) maps to ``assistant``. Empty turns are skipped.
    """
    messages = [{"role": "system", "content": SYSTEM_MESSAGE}]
    for speaker, text in history:
        if text:
            role = "user" if speaker == "User" else "assistant"
            messages.append({"role": role, "content": text})
    messages.append({"role": "user", "content": message})
    return messages


def _build_input_ids(message, history, max_new_tokens):
    """Tokenize the conversation with the model's chat template, on ``device``.

    Oldest history entries are dropped one at a time until the prompt plus
    *max_new_tokens* fits in MAX_CONTEXT_TOKENS. If *message* alone is too long
    the untrimmable prompt is returned anyway and the caller must check it.
    """
    turns = list(history)
    while True:
        input_ids = tokenizer.apply_chat_template(
            _chat_messages(message, turns),
            add_generation_prompt=True,  # end with the assistant header so the model replies
            return_tensors="pt",
            return_dict=False,
        )
        if not turns or input_ids.shape[-1] + max_new_tokens <= MAX_CONTEXT_TOKENS:
            return input_ids.to(device)
        turns.pop(0)  # forget the oldest turn and retry


def _run_generate(streamer, errors, **generate_kwargs):
    """Worker-thread target: run model.generate() feeding *streamer*.

    Any exception is appended to *errors* and the streamer is ended explicitly,
    because TextIteratorStreamer's iterator (timeout=None) would otherwise wait
    forever for tokens that will never arrive.
    """
    try:
        with torch.no_grad():
            model.generate(streamer=streamer, **generate_kwargs)
    except Exception as exc:
        errors.append(exc)
        streamer.end()


def generate_response(message, history, max_tokens, temperature, top_p):
    """Stream the model's reply to *message*, given the earlier turns in *history*.

    Args:
        message: The new user message.
        history: Earlier turns as ``(speaker, text)`` tuples, speaker ``"User"``
            or ``"AI"``. It must not already contain *message*.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The reply generated so far (cumulative text, not deltas). On failure a
        final ``"Error generating response: ..."`` string is yielded instead.
    """
    try:
        max_new_tokens = max(1, int(max_tokens))
        input_ids = _build_input_ids(message, history, max_new_tokens)
        # Bound new tokens (not total length) so the cap can never fall below the prompt.
        max_new_tokens = min(max_new_tokens, MAX_CONTEXT_TOKENS - input_ids.shape[-1])
        if max_new_tokens <= 0:
            raise ValueError(
                f"message exceeds the {MAX_CONTEXT_TOKENS}-token context limit"
            )

        # skip_prompt: generate() echoes the prompt into the streamer first; drop it.
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        errors = []
        worker = threading.Thread(
            target=_run_generate,
            args=(streamer, errors),
            kwargs=dict(
                input_ids=input_ids,
                attention_mask=torch.ones_like(input_ids),  # single unpadded sequence
                max_new_tokens=max_new_tokens,
                do_sample=True,  # without this, temperature/top_p are ignored (greedy)
                temperature=temperature,
                top_p=top_p,
                pad_token_id=tokenizer.eos_token_id,
            ),
            daemon=True,
        )
        worker.start()

        generated_text = ""
        for chunk in streamer:
            generated_text += chunk
            yield generated_text
        worker.join()
        if errors:
            raise errors[0]
    except Exception as e:
        yield f"Error generating response: {e}"
    finally:
        clean_memory()


def _to_chatbot_pairs(history):
    """Convert ``(speaker, text)`` history into gr.Chatbot ``[user, bot]`` pairs."""
    pairs = []
    for speaker, text in history:
        if speaker == "User":
            pairs.append([text, None])
        elif pairs and pairs[-1][1] is None:
            pairs[-1][1] = text
        else:
            pairs.append([None, text])
    return pairs


def update_chatbox(history, message, max_tokens, temperature, top_p):
    """Gradio handler: stream a reply to *message* and record the exchange.

    Args:
        history: Conversation so far as ``(speaker, text)`` tuples (``chat_history``).
        message: Text from the input box.
        max_tokens, temperature, top_p: Generation settings from the sliders.

    Yields:
        ``(history, chatbot_pairs, textbox_value)`` for the outputs
        ``[chat_history, chatbot, user_input]``. The input box is cleared when
        generation starts; *history* gains the user and AI turns only once the
        reply is complete.
    """
    if not (message and message.strip()):
        # Nothing to send: leave history and the input box as they are.
        yield history, _to_chatbot_pairs(history), message
        return

    shown = _to_chatbot_pairs(history)
    # Echo the user's message and clear the input box before the first token arrives.
    yield history, shown + [[message, None]], ""

    response = ""
    for response in generate_response(message, history, max_tokens, temperature, top_p):
        yield history, shown + [[message, response]], ""

    history = history + [("User", message), ("AI", response.strip())]
    yield history, _to_chatbot_pairs(history), ""


with gr.Blocks(css=".gradio-container {border: none;}") as demo:
    chat_history = gr.State([])
    max_tokens = gr.Slider(minimum=1, maximum=512, value=128, step=1, label="Max Tokens")
    temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(
        minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)"
    )
    chatbot = gr.Chatbot(label="Character-like AI Chat")
    user_input = gr.Textbox(show_label=False, placeholder="Type your message here...")
    send_button = gr.Button("Send")

    chat_inputs = [chat_history, user_input, max_tokens, temperature, top_p]
    # chat_history is an output so the updated conversation is stored back in the State.
    chat_outputs = [chat_history, chatbot, user_input]
    send_button.click(fn=update_chatbox, inputs=chat_inputs, outputs=chat_outputs, queue=True)
    user_input.submit(fn=update_chatbox, inputs=chat_inputs, outputs=chat_outputs, queue=True)

if __name__ == "__main__":
    # Generator (streaming) handlers need the request queue; older Gradio
    # versions do not enable it by default.
    demo.queue().launch(share=False)