igardner commited on
Commit
4c27286
Β·
1 Parent(s): 3652995

Ahhhhh lol

Browse files
Files changed (1) hide show
  1. app.py +118 -151
app.py CHANGED
@@ -2,7 +2,6 @@
2
  import gradio as gr
3
  from concept_steerer import ConceptSteerer
4
 
5
- # Global model instance (loaded once)
6
  steerer = None
7
 
8
  def get_steerer():
@@ -13,191 +12,159 @@ def get_steerer():
13
 
14
  def create_concept(name, pos_examples, neg_examples):
15
  if not name.strip():
16
- return "❌ Concept name cannot be empty."
17
  pos_list = [p.strip() for p in pos_examples.strip().split('\n') if p.strip()]
18
  neg_list = [n.strip() for n in neg_examples.strip().split('\n') if n.strip()]
19
  if not pos_list or not neg_list:
20
- return "❌ Please provide at least one example for both positive and negative prompts."
21
-
22
  try:
23
- steerer = get_steerer()
24
- steerer.register_concept(name.strip(), pos_list, neg_list, layer=-2)
25
- return f"βœ… Concept '{name.strip()}' registered successfully!"
26
  except Exception as e:
27
- return f"❌ Error: {str(e)}"
28
 
29
- def update_ui():
30
- """Returns updated UI components based on current concepts."""
31
- steerer = get_steerer()
32
- concepts = steerer.get_concept_names()
 
 
 
33
 
34
  if not concepts:
35
- return (
36
- gr.Markdown("πŸ’‘ No concepts yet. Create one in the 'Create Concepts' tab!"),
37
- gr.Row(visible=False),
38
- gr.Button(visible=True)
39
- )
40
 
41
- # Build sliders dynamically
42
  sliders = []
43
- with gr.Row():
44
- for concept in concepts:
45
- slider = gr.Slider(-5.0, 5.0, value=0.0, label=concept, step=0.1)
46
- sliders.append(slider)
47
-
48
- return (
49
- gr.Markdown(visible=False), # hide placeholder
50
- gr.Row(*sliders, visible=True),
51
- gr.Button(visible=True)
52
- )
53
-
54
- def generate_with_steering(prompt, *slider_values):
55
- steerer = get_steerer()
56
- concepts = steerer.get_concept_names()
57
- if not concepts:
58
- # No steering: just generate normally
59
- return steerer.generate(prompt, max_new_tokens=150)
60
 
61
- steering_config = {name: float(val) for name, val in zip(concepts, slider_values)}
 
 
 
 
 
 
 
 
62
  try:
63
- output = steerer.generate(prompt, steering_config=steering_config, max_new_tokens=150)
64
- return output
65
  except Exception as e:
66
- return f"❌ Generation error: {str(e)}"
67
-
68
- # Build the interface
69
- with gr.Blocks(title="🧠 LLM Concept Steering") as demo:
70
- gr.Markdown("# 🧠 LLM Concept Steering Demo")
71
- gr.Markdown("Steer Llama 3.2 1B's behavior using activation vectors β€” no training needed!")
72
 
 
 
 
 
73
  with gr.Tab("Create Concepts"):
74
  with gr.Row():
75
  with gr.Column():
76
- name_input = gr.Textbox(label="Concept Name", placeholder="e.g., Formality")
77
- pos_input = gr.Textbox(label="Positive Prompts (one per line)", lines=4)
78
- neg_input = gr.Textbox(label="Negative Prompts (one per line)", lines=4)
79
  create_btn = gr.Button("Register Concept")
80
  with gr.Column():
81
- status = gr.Textbox(label="Status", interactive=False)
82
  create_btn.click(
83
  create_concept,
84
- inputs=[name_input, pos_input, neg_input],
85
- outputs=status
86
  )
87
-
88
- with gr.Tab("Generate with Steering"):
89
- prompt_box = gr.Textbox(label="Enter your prompt", lines=3, placeholder="e.g., Write a short story about a robot.")
90
- gen_btn = gr.Button("Generate")
 
91
 
92
- # Placeholder for dynamic sliders
93
- placeholder_msg = gr.Markdown("Loading concepts...")
94
- slider_row = gr.Row(visible=False)
95
  output_box = gr.Textbox(label="Output", lines=8, interactive=False)
96
 
97
- # Update sliders when tab is loaded or after concept creation
98
- demo.load(update_ui, None, [placeholder_msg, slider_row, gen_btn])
99
-
100
- # Connect generate button β€” we pass all children of slider_row as inputs
101
- # But since we can't reference them before creation, we use a trick:
102
- # We'll re-render the generate button's dependency after UI update.
103
- # Simpler: just call generate_with_steering with all possible inputs.
104
- # Gradio handles missing inputs gracefully if we set default None.
105
-
106
- # Instead, we use a wrapper that fetches current slider values via State (not needed here)
107
- # Better: restructure to use a single State for config β€” but for simplicity, we do this:
108
-
109
- # We'll make the generate button depend on all possible sliders by rebuilding the UI
110
- # But an easier fix: use a single "Generate" that reads from a hidden state.
111
- # However, the cleanest modern way is to use **gr.on()** or just accept that we need to
112
- # capture the slider row's children.
113
-
114
- # WORKING SOLUTION: Use a lambda that captures the current UI
115
- def make_generate_fn():
116
- # This will be called after sliders exist
117
- children = slider_row.children
118
- if not children:
119
- inputs = [prompt_box]
120
- else:
121
- inputs = [prompt_box] + children
122
- gen_btn.click(
123
- generate_with_steering,
124
- inputs=inputs,
125
- outputs=output_box
126
- )
127
-
128
- # Unfortunately, we can't easily re-bind click after load in pure Python.
129
- # So here's a robust alternative: use a single button and read concepts at generate time.
130
-
131
- # REVISED PLAN: Don't use dynamic inputs to .click(). Instead, store slider values in State.
132
- # But to keep it simple and working NOW, let's use a different approach:
133
-
134
- # βœ… FINAL WORKING FIX: Use a single "steering_dict" stored in a Gradio State.
135
- # We'll add invisible number inputs that update a dict via change handlers.
136
-
137
- pass # We'll refactor below
138
-
139
- # ---- REFACTORED WORKING VERSION BELOW ----
140
-
141
- steering_state = gr.State({})
142
-
143
- def update_steering_state(name, value, state):
144
- state = state or {}
145
- state[name] = float(value)
146
- return state
147
-
148
- def generate_from_state(prompt, state):
149
- steerer = get_steerer()
150
- try:
151
- output = steerer.generate(prompt, steering_config=state or {}, max_new_tokens=150)
152
- return output
153
- except Exception as e:
154
- return f"❌ Error: {str(e)}"
155
 
156
- with gr.Blocks() as demo2:
157
- gr.Markdown("# 🧠 LLM Concept Steering (Fixed!)")
158
- gr.Markdown("Now fully compatible with modern Gradio.")
159
 
 
 
 
160
  with gr.Tab("Create Concepts"):
161
  with gr.Row():
162
  with gr.Column():
163
  name_in = gr.Textbox(label="Concept Name")
164
- pos_in = gr.Textbox(label="Positive Prompts", lines=3)
165
- neg_in = gr.Textbox(label="Negative Prompts", lines=3)
166
- create_btn2 = gr.Button("Register")
167
  with gr.Column():
168
- status2 = gr.Textbox(label="Result")
169
- create_btn2.click(create_concept, [name_in, pos_in, neg_in], status2)
170
-
 
 
 
 
 
 
 
171
  with gr.Tab("Generate"):
172
- prompt_in = gr.Textbox(label="Prompt", lines=2)
173
- state_holder = gr.State({})
174
- output_out = gr.Textbox(label="Output", lines=8)
175
-
176
- # Container for sliders
177
- slider_box = gr.Box()
178
-
179
- def build_sliders():
180
- steerer = get_steerer()
181
- concepts = steerer.get_concept_names()
182
- if not concepts:
183
- return [gr.Markdown("No concepts. Create one first!")]
184
- components = []
185
- for c in concepts:
186
- num = gr.Number(value=0.0, label=c, minimum=-5, maximum=5, step=0.1)
187
- num.change(
188
- lambda v, name=c, s=state_holder: update_steering_state(name, v, s),
189
- num,
190
- state_holder
191
- )
192
- components.append(num)
193
- return components
194
-
195
- demo2.load(build_sliders, None, slider_box)
196
 
197
- gen_btn2 = gr.Button("Generate")
198
- gen_btn2.click(generate_from_state, [prompt_in, state_holder], output_out)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
- demo = demo2
201
 
202
  if __name__ == "__main__":
203
  demo.launch(server_name="0.0.0.0", share=False)
 
2
  import gradio as gr
3
  from concept_steerer import ConceptSteerer
4
 
 
5
  steerer = None
6
 
7
  def get_steerer():
 
12
 
13
  def create_concept(name, pos_examples, neg_examples):
14
  if not name.strip():
15
+ return "❌ Concept name is required."
16
  pos_list = [p.strip() for p in pos_examples.strip().split('\n') if p.strip()]
17
  neg_list = [n.strip() for n in neg_examples.strip().split('\n') if n.strip()]
18
  if not pos_list or not neg_list:
19
+ return "❌ Provide at least one positive and one negative example."
 
20
  try:
21
+ s = get_steerer()
22
+ s.register_concept(name.strip(), pos_list, neg_list, layer=-2)
23
+ return f"βœ… Concept '{name}' registered!"
24
  except Exception as e:
25
+ return f"❌ Error: {e}"
26
 
27
+ def build_generation_ui():
28
+ """Returns prompt input, list of sliders, and output box."""
29
+ s = get_steerer()
30
+ concepts = s.get_concept_names()
31
+
32
+ prompt = gr.Textbox(label="Prompt", lines=2, placeholder="Enter your prompt...")
33
+ output = gr.Textbox(label="Output", lines=8, interactive=False)
34
 
35
  if not concepts:
36
+ sliders = [gr.Markdown("ℹ️ No concepts registered. Go to 'Create Concepts' tab.")]
37
+ return prompt, sliders, output
 
 
 
38
 
 
39
  sliders = []
40
+ for concept in concepts:
41
+ slider = gr.Slider(-5.0, 5.0, value=0.0, step=0.1, label=concept)
42
+ sliders.append(slider)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
+ return prompt, sliders, output
45
+
46
+ def generate_text(prompt, *slider_values):
47
+ s = get_steerer()
48
+ concepts = s.get_concept_names()
49
+ steering = {}
50
+ for name, val in zip(concepts, slider_values):
51
+ if abs(val) > 1e-6: # skip zero
52
+ steering[name] = float(val)
53
  try:
54
+ return s.generate(prompt, steering_config=steering, max_new_tokens=150)
 
55
  except Exception as e:
56
+ return f"❌ Generation failed: {e}"
 
 
 
 
 
57
 
58
+ # Main interface
59
+ with gr.Blocks() as demo:
60
+ gr.Markdown("# 🧠 LLM Concept Steering β€” Working Edition")
61
+
62
  with gr.Tab("Create Concepts"):
63
  with gr.Row():
64
  with gr.Column():
65
+ name_in = gr.Textbox(label="Concept Name")
66
+ pos_in = gr.Textbox(label="Positive Prompts (one per line)", lines=4)
67
+ neg_in = gr.Textbox(label="Negative Prompts (one per line)", lines=4)
68
  create_btn = gr.Button("Register Concept")
69
  with gr.Column():
70
+ status_out = gr.Textbox(label="Status", interactive=False)
71
  create_btn.click(
72
  create_concept,
73
+ inputs=[name_in, pos_in, neg_in],
74
+ outputs=status_out
75
  )
76
+
77
+ with gr.Tab("Generate"):
78
+ # We will dynamically replace this row
79
+ with gr.Row() as dynamic_row:
80
+ pass
81
 
82
+ gen_btn = gr.Button("Generate")
 
 
83
  output_box = gr.Textbox(label="Output", lines=8, interactive=False)
84
 
85
+ # Function to refresh the Generate tab UI
86
+ def refresh_ui():
87
+ prompt, sliders, _ = build_generation_ui()
88
+ return [prompt] + sliders
89
+
90
+ # Load initial UI
91
+ demo.load(
92
+ refresh_ui,
93
+ inputs=None,
94
+ outputs=None,
95
+ # We need to define the exact output components.
96
+ # So we pre-define a fixed max number of sliders (e.g., 10)
97
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
+ # 🚨 CRITICAL: Gradio requires fixed output structure.
100
+ # So we pre-define up to 10 sliders (more than enough for a demo).
101
+ MAX_CONCEPTS = 10
102
 
103
+ with gr.Blocks() as demo_fixed:
104
+ gr.Markdown("# 🧠 LLM Concept Steering β€” Working Edition")
105
+
106
  with gr.Tab("Create Concepts"):
107
  with gr.Row():
108
  with gr.Column():
109
  name_in = gr.Textbox(label="Concept Name")
110
+ pos_in = gr.Textbox(label="Positive Prompts (one per line)", lines=4)
111
+ neg_in = gr.Textbox(label="Negative Prompts (one per line)", lines=4)
112
+ create_btn = gr.Button("Register Concept")
113
  with gr.Column():
114
+ status_out = gr.Textbox(label="Status", interactive=False)
115
+ create_btn.click(
116
+ create_concept,
117
+ inputs=[name_in, pos_in, neg_in],
118
+ outputs=status_out
119
+ ).then( # After registering, refresh the other tab
120
+ None, None, None,
121
+ _js="() => { document.querySelectorAll('.tabs button')[1].click(); }" # Switch to Generate tab
122
+ )
123
+
124
  with gr.Tab("Generate"):
125
+ prompt_input = gr.Textbox(label="Prompt", lines=2, placeholder="Enter your prompt...")
126
+
127
+ # Pre-define 10 optional sliders
128
+ sliders = []
129
+ for i in range(MAX_CONCEPTS):
130
+ slider = gr.Slider(-5.0, 5.0, value=0.0, step=0.1, label=f"Concept {i+1}", visible=False)
131
+ sliders.append(slider)
132
+
133
+ gen_btn = gr.Button("Generate")
134
+ output_box = gr.Textbox(label="Output", lines=8, interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
+ def update_sliders():
137
+ s = get_steerer()
138
+ concepts = s.get_concept_names()
139
+ updates = []
140
+ for i in range(MAX_CONCEPTS):
141
+ if i < len(concepts):
142
+ updates.append(gr.Slider(visible=True, label=concepts[i]))
143
+ else:
144
+ updates.append(gr.Slider(visible=False))
145
+ return updates
146
+
147
+ def gen_wrapper(prompt, *slider_vals):
148
+ s = get_steerer()
149
+ concepts = s.get_concept_names()
150
+ steering = {
151
+ name: float(val)
152
+ for name, val in zip(concepts, slider_vals[:len(concepts)])
153
+ if abs(val) > 1e-6
154
+ }
155
+ return s.generate(prompt, steering_config=steering, max_new_tokens=150)
156
+
157
+ # Update sliders on tab load
158
+ demo_fixed.load(update_sliders, None, sliders)
159
+
160
+ # Connect generate
161
+ gen_btn.click(
162
+ gen_wrapper,
163
+ inputs=[prompt_input] + sliders,
164
+ outputs=output_box
165
+ )
166
 
167
+ demo = demo_fixed
168
 
169
  if __name__ == "__main__":
170
  demo.launch(server_name="0.0.0.0", share=False)