Spaces:
Runtime error (Space status banner; duplicated status line from page scrape merged)
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import gradio as gr | |
# --- Device selection and model loading (module-level side effects) ---
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("GPU not available, using CPU instead.")

# Load the tokenizer and model for the legal Q&A chat model.
model_id = "aman-augurs/mistral-7b-instruct-legal-qa-3e22-merged"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# device_map="auto" lets accelerate dispatch the weights across the
# available devices. BUG FIX: do NOT call model.to(device) afterwards —
# moving an accelerate-dispatched model raises a RuntimeError
# ("You can't move a model that has been dispatched ..."), which is a
# likely cause of this Space's runtime error.
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
model.eval()  # inference only: disable dropout and training-mode behavior
print(f"Model loaded on {device}")
def chat_with_model(user_input, chat_history=None):
    """Generate an assistant reply for *user_input* given prior turns.

    Args:
        user_input: The user's new message.
        chat_history: List of ``(user, assistant)`` tuples from earlier
            turns. Defaults to a fresh empty list.

    Returns:
        The updated chat history with ``(user_input, reply)`` appended.
    """
    # BUG FIX: the original used a mutable default (chat_history=[]),
    # which Python shares across calls — one session's history would
    # leak into the next when the argument is omitted.
    if chat_history is None:
        chat_history = []

    # Rebuild the conversation in the chat-template message format.
    messages = [{"role": "system", "content": "You are a helpful assistant."}]
    for user, assistant in chat_history:
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": user_input})

    # Tokenize with the model's chat template and move to the model device.
    # add_generation_prompt=True appends the template's assistant header so
    # the model continues with the assistant turn rather than the user's.
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(device)

    # Generate a response (sampling; no gradients needed for inference).
    with torch.no_grad():
        outputs = model.generate(
            inputs, max_new_tokens=512, do_sample=True, temperature=0.7
        )

    # Decode ONLY the newly generated tokens. This replaces the original's
    # fragile post-processing, which decoded the whole sequence and split
    # on the literal words "assistant"/"user" — corrupting any reply that
    # merely contained those words, and any user query echoed verbatim.
    generated_tokens = outputs[0][inputs.shape[-1]:]
    assistant_reply = tokenizer.decode(
        generated_tokens, skip_special_tokens=True
    ).strip()

    # Update and return the chat history.
    chat_history.append((user_input, assistant_reply))
    return chat_history
| # Define the Gradio interface | |
def gradio_chat_interface(user_input, chat_history=None):
    """Gradio callback: delegate to chat_with_model, return the new history.

    BUG FIX: replaces the shared mutable default (chat_history=[]).
    Gradio passes the Chatbot state explicitly, but a call that omits the
    argument must not reuse a previous call's list.
    """
    return chat_with_model(user_input, [] if chat_history is None else chat_history)
| # Create the Gradio app | |
# Assemble the Gradio UI — a chat pane, a message box, and send/clear
# buttons — then start the app.
with gr.Blocks() as demo:
    gr.Markdown("# Chat with Legal AI")
    history_display = gr.Chatbot(label="Chat History")
    message_box = gr.Textbox(label="Your Message")
    send_btn = gr.Button("Send")
    reset_btn = gr.Button("Clear Chat")

    # Sending routes the message plus the current history through the model.
    send_btn.click(
        fn=gradio_chat_interface,
        inputs=[message_box, history_display],
        outputs=history_display,
    )
    # Clearing resets the chat pane to an empty history.
    reset_btn.click(lambda: [], None, history_display, queue=False)

# Launch the app.
demo.launch()