import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

# Report the available compute device. This is informational only:
# device_map="auto" below lets accelerate place the weights.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("GPU not available, using CPU instead.")

# Load the fine-tuned legal-QA model and its tokenizer.
model_id = "aman-augurs/mistral-7b-instruct-legal-qa-3e22-merged"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# NOTE(review): the original called model.to(device) here. With
# device_map="auto" that is redundant and can raise a RuntimeError when
# accelerate has dispatched or offloaded the weights — removed.
model.eval()  # inference only; disables dropout etc.
print(f"Model loaded on {device}")


def chat_with_model(user_input, chat_history=None):
    """Generate an assistant reply for *user_input* given the conversation so far.

    Args:
        user_input: The user's latest message.
        chat_history: List of (user, assistant) message tuples, or None for a
            fresh conversation.

    Returns:
        The updated chat history including the new (user, assistant) pair.
    """
    # Avoid the mutable-default-argument pitfall: a shared [] default would
    # leak conversation state across independent calls/sessions.
    if chat_history is None:
        chat_history = []

    # Rebuild the message list in the chat-template format.
    messages = [{"role": "system", "content": "You are a helpful assistant."}]
    for user_msg, assistant_msg in chat_history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": user_input})

    # add_generation_prompt=True appends the assistant-turn prefix so the
    # model continues with a reply instead of predicting another user turn.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,  # silence the missing-pad warning
        )

    # Decode ONLY the newly generated tokens. The original decoded the whole
    # sequence and tried to recover the reply by splitting on the literal
    # words "assistant"/"user", which corrupts any answer that happens to
    # contain those words. Slicing off the prompt tokens is exact.
    generated_tokens = outputs[0][inputs.shape[-1]:]
    assistant_reply = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    chat_history.append((user_input, assistant_reply))
    return chat_history


def gradio_chat_interface(user_input, chat_history=None):
    """Gradio callback: delegate to chat_with_model and return the new history."""
    # Gradio may pass None for an empty Chatbot; normalize before delegating.
    return chat_with_model(user_input, chat_history or [])


# Build the Gradio app.
with gr.Blocks() as demo:
    gr.Markdown("# Chat with Legal AI")
    chatbot = gr.Chatbot(label="Chat History")
    user_input = gr.Textbox(label="Your Message")
    submit_button = gr.Button("Send")
    clear_button = gr.Button("Clear Chat")

    # Wire up the interactions.
    submit_button.click(
        fn=gradio_chat_interface,
        inputs=[user_input, chatbot],
        outputs=chatbot,
    )
    clear_button.click(lambda: [], None, chatbot, queue=False)

# Launch the app.
demo.launch()