File size: 4,381 Bytes
7ce594d
 
 
 
 
 
24f3a3d
7ce594d
 
 
 
 
 
 
 
 
24f3a3d
7ce594d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24f3a3d
7ce594d
 
 
 
 
 
 
 
 
24f3a3d
7ce594d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load model and tokenizer once at import time so every request reuses them.
# NOTE(review): a 54B model in bfloat16 is ~108 GB of weights; device_map="auto"
# lets accelerate shard layers across available devices — confirm the Space's
# hardware actually fits this.
model_id = "Equall/SaulLM-54B-Instruct"  # HF Hub repo id (legal-domain instruct model)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # half the memory of fp32 weights
    device_map="auto"  # automatic layer placement across GPUs/CPU
)

@spaces.GPU()
def generate_response(message, history, system_prompt, max_tokens, temperature):
    """Generate a legal-analysis reply with SaulLM.

    Args:
        message: Latest user question.
        history: Prior turns as (user, assistant) pairs.
        system_prompt: Optional system instruction; omitted when falsy.
        max_tokens: Cap on newly generated tokens. May arrive as a float
            from a gr.Slider, so it is coerced to int for `max_new_tokens`.
        temperature: Sampling temperature; values <= 0 select greedy decoding.

    Returns:
        The decoded assistant reply, with the echoed prompt tokens stripped.
    """
    # Assemble the transcript in the chat-template message format.
    messages = []

    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})

    messages.append({"role": "user", "content": message})

    # Render via the model's chat template and move tensors to its device.
    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    # Only pass sampling knobs when actually sampling: transformers warns
    # (and newer versions reject) temperature=0.0 combined with do_sample=False.
    gen_kwargs = {
        "max_new_tokens": int(max_tokens),  # Slider may deliver a float
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature > 0:
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = float(temperature)
    else:
        gen_kwargs["do_sample"] = False  # greedy decoding

    outputs = model.generate(**inputs, **gen_kwargs)

    # Decode only the newly generated tokens, not the echoed prompt.
    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    return response

# Default system prompt for legal queries. Shown (and user-editable) in the
# UI sidebar; passed verbatim as the "system" message to the model.
DEFAULT_SYSTEM = """You are SaulLM-54B, a specialized legal language model. You provide accurate legal analysis based on U.S. and European legal systems. 

IMPORTANT DISCLAIMERS:
- This is for informational purposes only, not legal advice
- Information may not reflect recent legal developments
- Users should consult qualified legal professionals for actual legal advice
- Do not use this for decisions that could affect legal rights"""

# Assemble the Gradio UI: chat pane on the left, generation controls on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# SaulLM-54B Legal Assistant")
    gr.Markdown("*Specialized AI for legal reasoning and analysis. Private queries, powered by Zero GPU (25 min/day free).*")

    with gr.Row():
        # Main chat column.
        with gr.Column(scale=3):
            chat_display = gr.Chatbot(label="Legal Analysis", height=500)
            question_box = gr.Textbox(
                label="Your Legal Question",
                placeholder="Ask about statutes, case law, legal concepts, or compliance...",
                lines=3,
            )
            with gr.Row():
                send_btn = gr.Button("Submit", variant="primary")
                reset_btn = gr.Button("Clear Chat")

        # Sidebar: prompt and sampling controls.
        with gr.Column(scale=1):
            sys_prompt_box = gr.Textbox(
                label="System Prompt",
                value=DEFAULT_SYSTEM,
                lines=12,
                max_lines=12,
            )
            token_cap_slider = gr.Slider(
                label="Max Response Tokens",
                minimum=100,
                maximum=2000,
                value=1000,
                step=100,
            )
            temp_slider = gr.Slider(
                label="Temperature",
                minimum=0.0,
                maximum=1.0,
                value=0.7,
                step=0.1,
            )
            gr.Markdown("### Usage Tips")
            gr.Markdown("""
            - Be specific about jurisdiction
            - Cite relevant statutes/cases if known
            - Zero GPU resets after 60s idle
            - 25 min/day free compute limit
            """)

    def _queue_user_turn(text, chat_log):
        # Clear the input box and append the user's turn with a pending reply.
        return "", chat_log + [[text, None]]

    def _fill_assistant_turn(chat_log, sys_prompt, token_cap, temp):
        # Answer the last (still unanswered) user turn using earlier turns as context.
        pending = chat_log[-1][0]
        earlier = chat_log[:-1]
        chat_log[-1][1] = generate_response(pending, earlier, sys_prompt, token_cap, temp)
        return chat_log

    # Enter key and Submit button share the same two-step pipeline:
    # echo the user turn immediately, then stream in the model's answer.
    question_box.submit(
        _queue_user_turn, [question_box, chat_display], [question_box, chat_display], queue=False
    ).then(
        _fill_assistant_turn, [chat_display, sys_prompt_box, token_cap_slider, temp_slider], chat_display
    )
    send_btn.click(
        _queue_user_turn, [question_box, chat_display], [question_box, chat_display], queue=False
    ).then(
        _fill_assistant_turn, [chat_display, sys_prompt_box, token_cap_slider, temp_slider], chat_display
    )
    reset_btn.click(lambda: None, None, chat_display, queue=False)

if __name__ == "__main__":
    # Enable the request queue, then start the Gradio server.
    demo.queue().launch()