igardner commited on
Commit
6b3f556
·
1 Parent(s): 4c27286

Just fucking around, still

Browse files
Files changed (1) hide show
  1. app.py +25 -103
app.py CHANGED
@@ -24,84 +24,33 @@ def create_concept(name, pos_examples, neg_examples):
24
  except Exception as e:
25
  return f"❌ Error: {e}"
26
 
27
- def build_generation_ui():
28
- """Returns prompt input, list of sliders, and output box."""
29
  s = get_steerer()
30
  concepts = s.get_concept_names()
31
-
32
- prompt = gr.Textbox(label="Prompt", lines=2, placeholder="Enter your prompt...")
33
- output = gr.Textbox(label="Output", lines=8, interactive=False)
34
-
35
- if not concepts:
36
- sliders = [gr.Markdown("ℹ️ No concepts registered. Go to 'Create Concepts' tab.")]
37
- return prompt, sliders, output
38
-
39
- sliders = []
40
- for concept in concepts:
41
- slider = gr.Slider(-5.0, 5.0, value=0.0, step=0.1, label=concept)
42
- sliders.append(slider)
43
-
44
- return prompt, sliders, output
45
 
46
- def generate_text(prompt, *slider_values):
47
  s = get_steerer()
48
  concepts = s.get_concept_names()
49
  steering = {}
50
- for name, val in zip(concepts, slider_values):
51
- if abs(val) > 1e-6: # skip zero
52
  steering[name] = float(val)
53
  try:
54
  return s.generate(prompt, steering_config=steering, max_new_tokens=150)
55
  except Exception as e:
56
- return f"❌ Generation failed: {e}"
57
 
58
- # Main interface
59
  with gr.Blocks() as demo:
60
- gr.Markdown("# 🧠 LLM Concept Steering — Working Edition")
61
-
62
- with gr.Tab("Create Concepts"):
63
- with gr.Row():
64
- with gr.Column():
65
- name_in = gr.Textbox(label="Concept Name")
66
- pos_in = gr.Textbox(label="Positive Prompts (one per line)", lines=4)
67
- neg_in = gr.Textbox(label="Negative Prompts (one per line)", lines=4)
68
- create_btn = gr.Button("Register Concept")
69
- with gr.Column():
70
- status_out = gr.Textbox(label="Status", interactive=False)
71
- create_btn.click(
72
- create_concept,
73
- inputs=[name_in, pos_in, neg_in],
74
- outputs=status_out
75
- )
76
-
77
- with gr.Tab("Generate"):
78
- # We will dynamically replace this row
79
- with gr.Row() as dynamic_row:
80
- pass
81
-
82
- gen_btn = gr.Button("Generate")
83
- output_box = gr.Textbox(label="Output", lines=8, interactive=False)
84
-
85
- # Function to refresh the Generate tab UI
86
- def refresh_ui():
87
- prompt, sliders, _ = build_generation_ui()
88
- return [prompt] + sliders
89
-
90
- # Load initial UI
91
- demo.load(
92
- refresh_ui,
93
- inputs=None,
94
- outputs=None,
95
- # We need to define the exact output components.
96
- # So we pre-define a fixed max number of sliders (e.g., 10)
97
- )
98
-
99
- # 🚨 CRITICAL: Gradio requires fixed output structure.
100
- # So we pre-define up to 10 sliders (more than enough for a demo).
101
- MAX_CONCEPTS = 10
102
-
103
- with gr.Blocks() as demo_fixed:
104
- gr.Markdown("# 🧠 LLM Concept Steering — Working Edition")
105
 
106
  with gr.Tab("Create Concepts"):
107
  with gr.Row():
@@ -116,55 +65,28 @@ with gr.Blocks() as demo_fixed:
116
  create_concept,
117
  inputs=[name_in, pos_in, neg_in],
118
  outputs=status_out
119
- ).then( # After registering, refresh the other tab
120
- None, None, None,
121
- _js="() => { document.querySelectorAll('.tabs button')[1].click(); }" # Switch to Generate tab
122
  )
123
 
124
  with gr.Tab("Generate"):
125
- prompt_input = gr.Textbox(label="Prompt", lines=2, placeholder="Enter your prompt...")
126
 
127
- # Pre-define 10 optional sliders
128
  sliders = []
129
- for i in range(MAX_CONCEPTS):
130
  slider = gr.Slider(-5.0, 5.0, value=0.0, step=0.1, label=f"Concept {i+1}", visible=False)
131
  sliders.append(slider)
132
 
133
  gen_btn = gr.Button("Generate")
134
- output_box = gr.Textbox(label="Output", lines=8, interactive=False)
135
-
136
- def update_sliders():
137
- s = get_steerer()
138
- concepts = s.get_concept_names()
139
- updates = []
140
- for i in range(MAX_CONCEPTS):
141
- if i < len(concepts):
142
- updates.append(gr.Slider(visible=True, label=concepts[i]))
143
- else:
144
- updates.append(gr.Slider(visible=False))
145
- return updates
146
-
147
- def gen_wrapper(prompt, *slider_vals):
148
- s = get_steerer()
149
- concepts = s.get_concept_names()
150
- steering = {
151
- name: float(val)
152
- for name, val in zip(concepts, slider_vals[:len(concepts)])
153
- if abs(val) > 1e-6
154
- }
155
- return s.generate(prompt, steering_config=steering, max_new_tokens=150)
156
-
157
- # Update sliders on tab load
158
- demo_fixed.load(update_sliders, None, sliders)
159
 
160
- # Connect generate
161
  gen_btn.click(
162
- gen_wrapper,
163
- inputs=[prompt_input] + sliders,
164
- outputs=output_box
165
  )
166
-
167
- demo = demo_fixed
 
168
 
169
  if __name__ == "__main__":
170
  demo.launch(server_name="0.0.0.0", share=False)
 
24
  except Exception as e:
25
  return f"❌ Error: {e}"
26
 
27
+ def update_sliders():
 
28
  s = get_steerer()
29
  concepts = s.get_concept_names()
30
+ MAX = 10
31
+ updates = []
32
+ for i in range(MAX):
33
+ if i < len(concepts):
34
+ updates.append(gr.update(visible=True, label=concepts[i]))
35
+ else:
36
+ updates.append(gr.update(visible=False))
37
+ return updates
 
 
 
 
 
 
38
 
39
+ def generate_text(prompt, *slider_vals):
40
  s = get_steerer()
41
  concepts = s.get_concept_names()
42
  steering = {}
43
+ for name, val in zip(concepts, slider_vals[:len(concepts)]):
44
+ if abs(val) > 1e-6:
45
  steering[name] = float(val)
46
  try:
47
  return s.generate(prompt, steering_config=steering, max_new_tokens=150)
48
  except Exception as e:
49
+ return f"❌ Error: {e}"
50
 
51
+ # Build UI with fixed 10 sliders (hidden by default)
52
  with gr.Blocks() as demo:
53
+ gr.Markdown("# 🧠 LLM Concept Steering — Working Version")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  with gr.Tab("Create Concepts"):
56
  with gr.Row():
 
65
  create_concept,
66
  inputs=[name_in, pos_in, neg_in],
67
  outputs=status_out
 
 
 
68
  )
69
 
70
  with gr.Tab("Generate"):
71
+ prompt_in = gr.Textbox(label="Prompt", lines=2, placeholder="e.g., Tell me a story.")
72
 
73
+ # Pre-create 10 sliders (will be shown/hidden dynamically)
74
  sliders = []
75
+ for i in range(10):
76
  slider = gr.Slider(-5.0, 5.0, value=0.0, step=0.1, label=f"Concept {i+1}", visible=False)
77
  sliders.append(slider)
78
 
79
  gen_btn = gr.Button("Generate")
80
+ output_out = gr.Textbox(label="Output", lines=8, interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
 
82
  gen_btn.click(
83
+ generate_text,
84
+ inputs=[prompt_in] + sliders,
85
+ outputs=output_out
86
  )
87
+
88
+ # Update sliders when the app loads or when tab is viewed
89
+ demo.load(update_sliders, inputs=None, outputs=sliders)
90
 
91
  if __name__ == "__main__":
92
  demo.launch(server_name="0.0.0.0", share=False)