# LegalAI / app.py — Gradio chat app serving the
# aman-augurs/mistral-7b-instruct-legal-qa-3e22-merged model.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

# Detect the compute device (used for input placement and logging). With
# device_map="auto" below, accelerate decides where the weights actually live.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("GPU not available, using CPU instead.")

# Load the tokenizer and model. device_map="auto" hands weight placement to
# accelerate, so the model must NOT be moved again with model.to(device):
# calling .to() on a dispatched model conflicts with the auto device map and
# raises an error in recent accelerate/transformers versions.
model_id = "aman-augurs/mistral-7b-instruct-legal-qa-3e22-merged"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
print(f"Model loaded on {device}")
def chat_with_model(user_input, chat_history=None):
    """Generate an assistant reply and append it to the conversation.

    Args:
        user_input: The latest user message.
        chat_history: List of ``(user, assistant)`` tuples from earlier
            turns. Defaults to a fresh empty list — the original used a
            mutable default (``[]``), which is shared across calls and
            leaks one session's history into the next.

    Returns:
        The updated chat history including the new ``(user_input, reply)``
        pair.
    """
    chat_history = [] if chat_history is None else chat_history

    # Rebuild the full conversation in the chat-template message format.
    messages = [{"role": "system", "content": "You are a helpful assistant."}]
    for user, assistant in chat_history:
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": user_input})

    # Tokenize via the model's chat template and move to the compute device.
    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)

    # Sampled generation; gradients are unnecessary at inference time.
    with torch.no_grad():
        outputs = model.generate(inputs, max_new_tokens=512, do_sample=True, temperature=0.7)

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Heuristic post-processing: keep only the text after the last
    # "assistant" marker emitted by the chat template.
    response_parts = response.split("assistant")
    if len(response_parts) > 1:
        assistant_reply = response_parts[-1].strip()
        # Strip leading punctuation artifacts left by the template.
        assistant_reply = assistant_reply.lstrip(". ").strip()
        # Drop an echoed copy of the user's query, if present.
        if user_input in assistant_reply:
            assistant_reply = assistant_reply.replace(user_input, "").strip()
        # Truncate anything after a stray "user" marker (next-turn leakage).
        assistant_reply = assistant_reply.split("user")[0].strip()
        # Collapse whitespace runs into single spaces.
        assistant_reply = ' '.join(assistant_reply.split())
    else:
        # No marker found — fall back to the raw decoded text.
        assistant_reply = response.strip()

    chat_history.append((user_input, assistant_reply))
    return chat_history
# Define the Gradio interface
def gradio_chat_interface(user_input, chat_history=None):
    """Gradio callback: run one chat turn and return the updated history.

    Args:
        user_input: Text from the message textbox.
        chat_history: Current chatbot state as ``(user, assistant)`` tuples,
            or ``None``. Defaults to ``None`` instead of a mutable ``[]``
            default, which was shared across calls and leaked history
            between sessions.

    Returns:
        The new chat history, which Gradio writes back into the Chatbot.
    """
    # Normalize None to a fresh list before delegating, so this function is
    # safe regardless of what default chat_with_model itself uses.
    chat_history = [] if chat_history is None else chat_history
    return chat_with_model(user_input, chat_history)
# Create the Gradio app
# Create the Gradio app: a minimal chat UI wired to the model callback.
with gr.Blocks() as demo:
    gr.Markdown("# Chat with Legal AI")
    # Chatbot component holds the conversation as (user, assistant) pairs.
    chatbot = gr.Chatbot(label="Chat History")
    user_input = gr.Textbox(label="Your Message")
    submit_button = gr.Button("Send")
    clear_button = gr.Button("Clear Chat")
    # Define the interaction
    # Send: feed the textbox text + current history into the model callback,
    # and write the returned history back into the Chatbot.
    submit_button.click(fn=gradio_chat_interface, inputs=[user_input, chatbot], outputs=chatbot)
    # Clear: replace the Chatbot state with an empty history; queue=False
    # runs it immediately without entering the request queue.
    clear_button.click(lambda: [], None, chatbot, queue=False)
# Launch the app
demo.launch()