aman-augurs committed
Commit 3200db8 · verified · 1 Parent(s): bd01003

Create app.py

Files changed (1)
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import gradio as gr
+
+ # Check if a GPU is available
+ if torch.cuda.is_available():
+     device = torch.device("cuda")
+     print(f"Using GPU: {torch.cuda.get_device_name(0)}")
+ else:
+     device = torch.device("cpu")
+     print("GPU not available, using CPU instead.")
+
+ # Load the model and tokenizer
+ model_id = "aman-augurs/mistral-7b-instruct-legal-qa-3e22-merged"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ # device_map="auto" lets Accelerate place the weights, so a separate
+ # model.to(device) is unnecessary and can conflict with the device map.
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
+ print(f"Model loaded on {device}")
+
+ def chat_with_model(user_input, chat_history=None):
+     # Avoid a mutable default argument
+     if chat_history is None:
+         chat_history = []
+
+     # Format the chat history for the model
+     messages = [{"role": "system", "content": "You are a helpful assistant."}]
+     for user, assistant in chat_history:
+         messages.append({"role": "user", "content": user})
+         messages.append({"role": "assistant", "content": assistant})
+     messages.append({"role": "user", "content": user_input})
+
+     # Tokenize the input and move it to the model's device
+     inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
+
+     # Generate a response
+     with torch.no_grad():
+         outputs = model.generate(inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
+
+     # Decode only the newly generated tokens; slicing off the prompt is
+     # more robust than splitting the full decoded string on "assistant"
+     assistant_reply = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True).strip()
+
+     # Update the chat history
+     chat_history.append((user_input, assistant_reply))
+
+     return chat_history
+
+ # Define the Gradio interface
+ def gradio_chat_interface(user_input, chat_history):
+     chat_history = chat_with_model(user_input, chat_history or [])
+     return chat_history
+
+ # Create the Gradio app
+ with gr.Blocks() as demo:
+     gr.Markdown("# Chat with Legal AI")
+     chatbot = gr.Chatbot(label="Chat History")
+     user_input = gr.Textbox(label="Your Message")
+     submit_button = gr.Button("Send")
+     clear_button = gr.Button("Clear Chat")
+
+     # Define the interaction
+     submit_button.click(fn=gradio_chat_interface, inputs=[user_input, chatbot], outputs=chatbot)
+     clear_button.click(lambda: [], None, chatbot, queue=False)
+
+ # Launch the app
+ demo.launch()