File size: 3,234 Bytes
22b6988
abb557b
1046b9c
 
 
 
3652995
1046b9c
3652995
 
 
1046b9c
 
3652995
4c27286
3652995
 
 
4c27286
22b6988
4c27286
 
 
22b6988
4c27286
7955305
6b3f556
4c27286
 
6b3f556
 
 
 
 
 
 
 
4c27286
6b3f556
4c27286
 
 
6b3f556
 
4c27286
22b6988
4c27286
22b6988
6b3f556
3652995
6b3f556
4c27286
6b3f556
4c27286
3652995
1046b9c
 
3652995
4c27286
 
 
1046b9c
4c27286
 
 
 
 
 
 
3652995
6b3f556
4c27286
6b3f556
4c27286
6b3f556
4c27286
 
 
 
6b3f556
4c27286
 
6b3f556
 
 
4c27286
6b3f556
 
 
7955305
abb557b
3652995
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# app.py
import gradio as gr
from concept_steerer import ConceptSteerer

steerer = None  # shared ConceptSteerer, created lazily by get_steerer()


def get_steerer():
    """Return the process-wide ConceptSteerer, constructing it on first call.

    Every later call hands back the same instance, so concepts registered
    through it are visible to all other handlers that call this function.
    """
    global steerer
    if steerer is not None:
        return steerer
    steerer = ConceptSteerer(model_name="unsloth/Llama-3.2-1B-Instruct")
    return steerer

def _parse_examples(text):
    """Split a multi-line textbox value into its non-blank, stripped lines."""
    return [line.strip() for line in (text or "").splitlines() if line.strip()]


def create_concept(name, pos_examples, neg_examples):
    """Register a steering concept from the "Create Concepts" form.

    Args:
        name: Concept name typed by the user; surrounding whitespace is ignored.
        pos_examples: Positive prompts, one per line (blank lines skipped).
        neg_examples: Negative prompts, one per line (blank lines skipped).

    Returns:
        A status string for the Status textbox. Failures are reported in the
        returned string rather than raised, so the UI always shows feedback.
    """
    name = (name or "").strip()
    if not name:
        return "❌ Concept name is required."
    pos_list = _parse_examples(pos_examples)
    neg_list = _parse_examples(neg_examples)
    if not pos_list or not neg_list:
        return "❌ Provide at least one positive and one negative example."
    try:
        s = get_steerer()
        # NOTE(review): layer=-2 presumably indexes layers from the end
        # (second-to-last) — confirm against ConceptSteerer.register_concept.
        s.register_concept(name, pos_list, neg_list, layer=-2)
    except Exception as e:  # surface any failure in the UI instead of crashing
        return f"❌ Error: {e}"
    # Echo the same (stripped) name that was actually registered.
    return f"✅ Concept '{name}' registered!"

def update_sliders():
    """Show one slider per registered concept, labelled with its name; hide the rest.

    Always returns exactly 10 ``gr.update`` objects, one per pre-built slider
    slot. Concepts beyond the 10th get no slider.
    """
    names = get_steerer().get_concept_names()
    slot_count = 10
    return [
        gr.update(visible=True, label=names[idx])
        if idx < len(names)
        else gr.update(visible=False)
        for idx in range(slot_count)
    ]

def _build_steering_config(concepts, slider_vals):
    """Map concept names to their slider strengths, skipping (near-)zero values.

    Names and values are paired positionally; surplus values (e.g. from hidden
    slider slots) are ignored, and ``None`` values are skipped.
    """
    steering = {}
    for name, val in zip(concepts, slider_vals):
        if val is not None and abs(val) > 1e-6:
            steering[name] = float(val)
    return steering


def generate_text(prompt, *slider_vals):
    """Generate a completion for *prompt* using the active concept strengths.

    Args:
        prompt: User prompt from the "Generate" tab.
        *slider_vals: Slider values in UI order; the i-th value applies to the
            i-th name returned by ``get_concept_names()``.

    Returns:
        The generated text, or an "❌ Error: ..." string on any failure. This
        includes failing to obtain the steerer, so the Output box always
        receives a value (matching how create_concept reports errors).
    """
    try:
        s = get_steerer()
        steering = _build_steering_config(s.get_concept_names(), slider_vals)
        return s.generate(prompt, steering_config=steering, max_new_tokens=150)
    except Exception as e:
        return f"❌ Error: {e}"

# Build the UI. Ten sliders are created up front (hidden); update_sliders()
# shows and labels one per registered concept.
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 LLM Concept Steering — Working Version")

    with gr.Tab("Create Concepts"):
        with gr.Row():
            with gr.Column():
                name_in = gr.Textbox(label="Concept Name")
                pos_in = gr.Textbox(label="Positive Prompts (one per line)", lines=4)
                neg_in = gr.Textbox(label="Negative Prompts (one per line)", lines=4)
                create_btn = gr.Button("Register Concept")
            with gr.Column():
                status_out = gr.Textbox(label="Status", interactive=False)

    with gr.Tab("Generate"):
        prompt_in = gr.Textbox(label="Prompt", lines=2, placeholder="e.g., Tell me a story.")

        # Pre-create 10 hidden slider slots (shown/relabelled dynamically).
        sliders = [
            gr.Slider(-5.0, 5.0, value=0.0, step=0.1, label=f"Concept {i + 1}", visible=False)
            for i in range(10)
        ]

        gen_btn = gr.Button("Generate")
        output_out = gr.Textbox(label="Output", lines=8, interactive=False)

        gen_btn.click(
            generate_text,
            inputs=[prompt_in] + sliders,
            outputs=output_out,
        )

    # Wired after both tabs so the slider components already exist. Refreshing
    # the sliders after each registration makes a newly registered concept's
    # slider appear without reloading the page.
    create_btn.click(
        create_concept,
        inputs=[name_in, pos_in, neg_in],
        outputs=status_out,
    ).then(update_sliders, inputs=None, outputs=sliders)

    # Populate the sliders for already-registered concepts when a page loads.
    demo.load(update_sliders, inputs=None, outputs=sliders)

if __name__ == "__main__":
    # Bind to all network interfaces (0.0.0.0); no public Gradio share link.
    demo.launch(share=False, server_name="0.0.0.0")