File size: 3,234 Bytes
22b6988 abb557b 1046b9c 3652995 1046b9c 3652995 1046b9c 3652995 4c27286 3652995 4c27286 22b6988 4c27286 22b6988 4c27286 7955305 6b3f556 4c27286 6b3f556 4c27286 6b3f556 4c27286 6b3f556 4c27286 22b6988 4c27286 22b6988 6b3f556 3652995 6b3f556 4c27286 6b3f556 4c27286 3652995 1046b9c 3652995 4c27286 1046b9c 4c27286 3652995 6b3f556 4c27286 6b3f556 4c27286 6b3f556 4c27286 6b3f556 4c27286 6b3f556 4c27286 6b3f556 7955305 abb557b 3652995 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
# app.py
import gradio as gr
from concept_steerer import ConceptSteerer
# Process-wide ConceptSteerer instance; stays None until first requested.
steerer = None


def get_steerer():
    """Return the shared ConceptSteerer, constructing it on first use.

    The instance is cached in the module-level ``steerer`` variable so that
    every UI callback reuses the same object instead of building a new one.
    """
    global steerer
    if steerer is not None:
        return steerer
    steerer = ConceptSteerer(model_name="unsloth/Llama-3.2-1B-Instruct")
    return steerer
def create_concept(name, pos_examples, neg_examples):
    """Register a steering concept from newline-separated example prompts.

    Args:
        name: Concept name; surrounding whitespace is ignored.
        pos_examples: Positive prompts, one per line. Blank lines are skipped.
        neg_examples: Negative prompts, one per line. Blank lines are skipped.

    Returns:
        A status string for the UI. It starts with "✅" when the concept was
        registered and with "❌" when validation failed or registration
        raised an exception.
    """
    # Strip once so the registered name and the reported name always agree.
    concept_name = name.strip()
    if not concept_name:
        return "❌ Concept name is required."

    pos_list = [p.strip() for p in pos_examples.strip().split('\n') if p.strip()]
    neg_list = [n.strip() for n in neg_examples.strip().split('\n') if n.strip()]
    if not pos_list or not neg_list:
        return "❌ Provide at least one positive and one negative example."

    try:
        s = get_steerer()
        s.register_concept(concept_name, pos_list, neg_list, layer=-2)
    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        return f"❌ Error: {e}"
    return f"✅ Concept '{concept_name}' registered!"
def update_sliders():
    """Build one Gradio update per pre-created slider.

    The first ``len(names)`` sliders are made visible and labelled with the
    corresponding concept name; the remaining ones are hidden. Exactly ten
    updates are always returned, so concepts beyond the tenth are not shown.
    """
    slider_count = 10
    names = get_steerer().get_concept_names()
    return [
        gr.update(visible=True, label=names[idx])
        if idx < len(names)
        else gr.update(visible=False)
        for idx in range(slider_count)
    ]
def generate_text(prompt, *slider_vals):
    """Generate text for ``prompt``, steered by the non-zero slider values.

    Args:
        prompt: Prompt text passed through to the steerer's ``generate``.
        *slider_vals: Slider values, matched positionally against the names
            returned by ``get_concept_names()``. Surplus values are ignored.

    Returns:
        Whatever the steerer's ``generate`` returns, or a string starting with
        "❌ Error:" if obtaining the steerer or generating raised an exception.
    """
    try:
        # Obtaining the steerer may fail too, so it lives inside the try
        # and is reported the same way create_concept reports its errors.
        s = get_steerer()
        concepts = s.get_concept_names()
        # zip() stops at the shorter input, so extra sliders drop out here.
        # Sliders left at (effectively) zero are omitted from the config.
        steering = {
            name: float(val)
            for name, val in zip(concepts, slider_vals)
            if abs(val) > 1e-6
        }
        return s.generate(prompt, steering_config=steering, max_new_tokens=150)
    except Exception as e:
        # UI boundary: show the failure in the output box instead of raising.
        return f"❌ Error: {e}"
# Build the UI with a fixed pool of 10 sliders (hidden by default). The pool
# size must match the number of updates returned by update_sliders().
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 LLM Concept Steering — Working Version")

    with gr.Tab("Create Concepts"):
        with gr.Row():
            with gr.Column():
                name_in = gr.Textbox(label="Concept Name")
                pos_in = gr.Textbox(label="Positive Prompts (one per line)", lines=4)
                neg_in = gr.Textbox(label="Negative Prompts (one per line)", lines=4)
                create_btn = gr.Button("Register Concept")
            with gr.Column():
                status_out = gr.Textbox(label="Status", interactive=False)

    with gr.Tab("Generate"):
        prompt_in = gr.Textbox(label="Prompt", lines=2, placeholder="e.g., Tell me a story.")
        # Pre-create 10 sliders; update_sliders() shows and labels one per
        # registered concept and hides the rest.
        sliders = [
            gr.Slider(-5.0, 5.0, value=0.0, step=0.1, label=f"Concept {i+1}", visible=False)
            for i in range(10)
        ]
        gen_btn = gr.Button("Generate")
        output_out = gr.Textbox(label="Output", lines=8, interactive=False)
        gen_btn.click(
            generate_text,
            inputs=[prompt_in] + sliders,
            outputs=output_out
        )

    # Wired after the sliders exist so that registering a concept can refresh
    # them immediately; otherwise a new concept's slider would only appear
    # after a page reload.
    create_btn.click(
        create_concept,
        inputs=[name_in, pos_in, neg_in],
        outputs=status_out
    ).then(
        update_sliders,
        inputs=None,
        outputs=sliders
    )

    # Also sync the sliders with the registered concepts on page load.
    demo.load(update_sliders, inputs=None, outputs=sliders)
if __name__ == "__main__":
    # Listen on all network interfaces; do not create a public share link.
    demo.launch(server_name="0.0.0.0", share=False)