|
|
import gradio as gr |
|
|
import spaces |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
import torch |
|
|
|
|
|
|
|
|
# Hugging Face repo id for the 54B-parameter instruction-tuned legal LLM.
model_id = "Equall/SaulLM-54B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)

# bfloat16 halves weight memory vs float32; device_map="auto" lets accelerate
# shard/place the (very large) model across the available GPU(s) and CPU.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
|
|
|
|
|
@spaces.GPU()
def generate_response(message, history, system_prompt, max_tokens, temperature):
    """Generate a legal analysis reply with SaulLM-54B.

    Args:
        message: The latest user question (str).
        history: Prior turns as (user_text, assistant_text) pairs.
        system_prompt: Optional system instruction; omitted when falsy.
        max_tokens: Cap on the number of newly generated tokens.
        temperature: Sampling temperature; 0 selects greedy decoding.

    Returns:
        The decoded assistant reply (str), with the prompt and special
        tokens stripped.
    """
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})

    input_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    do_sample = temperature > 0
    gen_kwargs = {
        "max_new_tokens": max_tokens,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        # Only pass temperature when actually sampling: temperature=0.0 with
        # generate() triggers transformers validation warnings/errors, and it
        # is ignored under greedy decoding anyway.
        gen_kwargs["temperature"] = temperature

    # inference_mode disables autograd bookkeeping — significant memory
    # savings when decoding with a 54B-parameter model.
    with torch.inference_mode():
        outputs = model.generate(**inputs, **gen_kwargs)

    # Slice off the prompt tokens so only the new continuation is decoded.
    response = tokenizer.decode(
        outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True
    )
    return response
|
|
|
|
|
|
|
|
# Default system prompt shown (and editable) in the UI; front-loads the
# legal-advice disclaimers so every conversation starts with them.
DEFAULT_SYSTEM = """You are SaulLM-54B, a specialized legal language model. You provide accurate legal analysis based on U.S. and European legal systems.

IMPORTANT DISCLAIMERS:
- This is for informational purposes only, not legal advice
- Information may not reflect recent legal developments
- Users should consult qualified legal professionals for actual legal advice
- Do not use this for decisions that could affect legal rights"""
|
|
|
|
|
|
|
|
# ---- Gradio UI -----------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# SaulLM-54B Legal Assistant")
    gr.Markdown("*Specialized AI for legal reasoning and analysis. Private queries, powered by Zero GPU (25 min/day free).*")

    with gr.Row():
        # Left (wider) column: conversation view plus input controls.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Legal Analysis", height=500)
            msg = gr.Textbox(
                label="Your Legal Question",
                placeholder="Ask about statutes, case law, legal concepts, or compliance...",
                lines=3
            )
            with gr.Row():
                submit = gr.Button("Submit", variant="primary")
                clear = gr.Button("Clear Chat")

        # Right (narrow) column: generation settings passed to the model.
        with gr.Column(scale=1):
            system_prompt = gr.Textbox(
                label="System Prompt",
                value=DEFAULT_SYSTEM,
                lines=12,
                max_lines=12
            )
            max_tokens = gr.Slider(
                label="Max Response Tokens",
                minimum=100,
                maximum=2000,
                value=1000,
                step=100
            )
            # Slider value feeds generate_response; 0.0 means greedy decoding.
            temperature = gr.Slider(
                label="Temperature",
                minimum=0.0,
                maximum=1.0,
                value=0.7,
                step=0.1
            )
            gr.Markdown("### Usage Tips")
            gr.Markdown("""
- Be specific about jurisdiction
- Cite relevant statutes/cases if known
- Zero GPU resets after 60s idle
- 25 min/day free compute limit
""")
|
|
|
|
|
def user_submit(message, history): |
|
|
return "", history + [[message, None]] |
|
|
|
|
|
def bot_respond(history, system_prompt, max_tokens, temperature): |
|
|
message = history[-1][0] |
|
|
history_context = history[:-1] |
|
|
|
|
|
response = generate_response(message, history_context, system_prompt, max_tokens, temperature) |
|
|
history[-1][1] = response |
|
|
return history |
|
|
|
|
|
    # Enter key in the textbox: instantly echo the message (queue=False skips
    # the GPU queue for the UI-only step), then run generation via .then().
    msg.submit(user_submit, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_respond, [chatbot, system_prompt, max_tokens, temperature], chatbot
    )
    # Submit button: same two-step pipeline as pressing Enter.
    submit.click(user_submit, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_respond, [chatbot, system_prompt, max_tokens, temperature], chatbot
    )
    # Clear button: reset the chat display (None empties the Chatbot).
    clear.click(lambda: None, None, chatbot, queue=False)
|
|
|
|
|
if __name__ == "__main__":
    # Enable request queuing (required for Zero GPU Spaces) before serving.
    app = demo.queue()
    app.launch()
|
|
|