Spaces: Build error

import gradio as gr
import torch
import gc
import threading
import time
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

try:
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        torch_dtype=torch.float16,
        device_map="auto",
    )
    device = model.device  # get the device automatically (device_map="auto" decides placement)
    print(f"Model loaded on {device}")
except Exception as e:
    print(f"Error loading model: {e}")
    exit(1)

def clean_memory():
    # Periodic garbage collection; daemon=True below means this loop never blocks shutdown.
    while True:
        gc.collect()
        if device.type == 'cuda':  # check device type explicitly
            torch.cuda.empty_cache()
        time.sleep(1)

cleanup_thread = threading.Thread(target=clean_memory, daemon=True)
cleanup_thread.start()

def generate_response(message, history, max_tokens, temperature, top_p):
    try:
        system_message = "You are a helpful and friendly AI assistant."
        # history holds (user_message, ai_message) pairs, matching gr.Chatbot's format.
        prompt = system_message + "\n"
        for user_text, ai_text in history:
            prompt += f"User: {user_text}\nAI: {ai_text}\n"
        prompt += f"User: {message}\nAI: "
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
        # generate() has no stream=True flag; stream via TextIteratorStreamer with generation in a background thread.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            input_ids=input_ids, streamer=streamer,
            max_new_tokens=max_tokens,  # cap new tokens to prevent excessive generation
            do_sample=True, temperature=temperature, top_p=top_p,  # temperature/top_p only apply when sampling
            pad_token_id=tokenizer.eos_token_id,
        )
        threading.Thread(target=model.generate, kwargs=generation_kwargs).start()
        generated_text = ""
        for new_text in streamer:  # yields decoded text as it is generated
            generated_text += new_text
            yield generated_text
    except Exception as e:
        yield f"Error generating response: {e}"

def update_chatbox(history, message, max_tokens, temperature, top_p):
    # gr.Chatbot renders (user_message, ai_message) pairs; stream the reply into the
    # newest pair so partial text appears in the chat window, not the input textbox.
    history.append([message, ""])
    # Pass history[:-1] so the just-appended turn is not duplicated in the prompt.
    for response_chunk in generate_response(message, history[:-1], max_tokens, temperature, top_p):
        history[-1][1] = response_chunk
        yield history, ""  # clear the textbox while streaming
    history[-1][1] = history[-1][1].strip()
    yield history, ""

with gr.Blocks(css=".gradio-container {border: none;}") as demo:
    chat_history = gr.State([])
    max_tokens = gr.Slider(minimum=1, maximum=512, value=128, step=1, label="Max Tokens")
    temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)")
    chatbot = gr.Chatbot(label="Character-like AI Chat")
    user_input = gr.Textbox(show_label=False, placeholder="Type your message here...")
    send_button = gr.Button("Send")
    send_button.click(
        fn=update_chatbox,
        inputs=[chat_history, user_input, max_tokens, temperature, top_p],
        outputs=[chatbot, user_input],
        queue=True,
    )

if __name__ == "__main__":
    demo.queue()  # generator (streaming) handlers need the queue enabled; harmless where it is on by default
    demo.launch(share=False)
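
The build error itself is often a dependency or access problem rather than a code bug. A minimal requirements.txt sketch for this Space, inferred from the imports above (the version pins are assumptions, not taken from the original):

# requirements.txt -- pins are assumptions; match them to the Space's runtime
gradio
torch
transformers>=4.43  # Llama 3.1 support landed in the 4.43 release line
accelerate          # required for device_map="auto" in from_pretrained

meta-llama/Meta-Llama-3.1-8B-Instruct is also a gated repository, so the Space must be authenticated before from_pretrained runs. A minimal sketch, assuming the token is stored in an HF_TOKEN Space secret:

import os
from huggingface_hub import login

# HF_TOKEN is an assumed secret name; recent huggingface_hub versions also
# pick this variable up automatically without an explicit login() call.
token = os.environ.get("HF_TOKEN")
if token:
    login(token=token)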