Spaces:
Sleeping
Sleeping
import logging
from enum import Enum

import gradio as gr
import requests
class Model(Enum):
    """Models the analyzer can query on Neuronpedia.

    Each member's value is sent verbatim as the API ``modelId`` and, lowercased,
    forms the first path segment of the Neuronpedia embed URL.
    """

    GEMMA = "gemma-2-2b"  # selected by the "Gemini" option in the UI
    GPT2 = "gpt2-small"  # selected by any other option ("OpenAI") in the UI
# Neuronpedia SAE source identifier for each model. get_features() sends it as
# the API "layer" field; create_dashboard() lowercases it into the embed URL.
MODEL_CONFIGS = {
    Model.GEMMA: "20-gemmascope-res-16k",
    Model.GPT2: "9-res-jb",
}
def get_features(text: str, model: Model, timeout: float = 30.0):
    """Ask Neuronpedia for the top-activating SAE features of each token in ``text``.

    Args:
        text: Text to tokenize and analyze.
        model: Model to query; its SAE source id comes from ``MODEL_CONFIGS``.
        timeout: Seconds to wait for the API before giving up.

    Returns:
        The decoded JSON response on success, or ``None`` when the request
        fails (network error, HTTP error status, timeout) or the body is not
        valid JSON. Failures are logged rather than raised so the UI can fall
        back to an empty result.
    """
    url = "https://www.neuronpedia.org/api/search-with-topk"
    payload = {
        "modelId": model.value,
        "text": text,
        "layer": MODEL_CONFIGS[model],
    }
    try:
        # A timeout is essential: without one a stalled server blocks this worker forever.
        response = requests.post(
            url,
            headers={"Content-Type": "application/json"},
            json=payload,
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError):
        # ValueError covers JSON decode failures on requests versions where
        # the decode error is not a RequestException subclass.
        logging.getLogger(__name__).exception(
            "Neuronpedia feature search failed for model %s", model.value
        )
        return None
def create_dashboard(feature_id: int, model: Model) -> str:
    """Return an HTML snippet that embeds the Neuronpedia page for one feature.

    Args:
        feature_id: Index of the SAE feature to show.
        model: Model whose configured SAE source the feature belongs to.

    Returns:
        A ``dashboard-container`` div wrapping an iframe pointed at the
        Neuronpedia embed view, with the explanation, plots and test panels
        enabled.
    """
    model_id = model.value.lower()
    source_id = MODEL_CONFIGS[model].lower()
    embed_src = (
        f"https://www.neuronpedia.org/{model_id}/{source_id}/{feature_id}"
        "?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
    )
    return f"""
<div class="dashboard-container p-4">
<h3 class="text-lg font-semibold mb-4">Feature {feature_id} Dashboard</h3>
<iframe
src="{embed_src}"
width="100%"
height="600"
frameborder="0"
class="rounded-lg"
></iframe>
</div>
"""
def handle_feature_click(feature_id: int, model: str):
    """Build the dashboard HTML for a feature button the user clicked.

    Args:
        feature_id: SAE feature index attached to the clicked button.
        model: UI model label. ``"Gemini"`` selects Gemma; any other label
            falls back to GPT-2.

    Returns:
        The HTML produced by ``create_dashboard`` for that feature.
    """
    if model == "Gemini":
        resolved = Model.GEMMA
    else:
        resolved = Model.GPT2
    return create_dashboard(feature_id, resolved)
def analyze_text(text: str, selected_model: str, top_k: int = 3):
    """Analyze ``text`` with Neuronpedia and build the initial dashboard.

    Args:
        text: User input to analyze.
        selected_model: UI model label. ``"Gemini"`` selects Gemma; any other
            label selects GPT-2.
        top_k: Maximum number of top features kept per token.

    Returns:
        A ``(features, dashboard_html)`` tuple. ``features`` is a list of
        ``{"token": str, "features": [{"token", "id", "activation"}, ...]}``
        dicts, one per non-special token. ``dashboard_html`` embeds the first
        feature found. It is ``""`` when the input is blank, the API call
        fails, or no feature was returned.
    """
    # Blank input: skip the network round-trip entirely.
    if not text or not text.strip():
        return [], ""
    model = Model.GEMMA if selected_model == "Gemini" else Model.GPT2
    features_data = get_features(text, model)
    if not features_data:
        return [], ""

    # Beginning-of-sequence markers carry no information about the user's text.
    # "<|endoftext|>" is GPT-2's BOS token; assumed to be echoed back by the API
    # for gpt2-small the same way "<bos>" is for Gemma (NOTE(review): confirm).
    skip_tokens = {"<bos>", "<|endoftext|>"}

    features = []
    first_feature_id = None
    for result in features_data.get("results", []):
        token = result["token"]
        if token in skip_tokens:
            continue
        token_features = [
            {
                "token": token,
                "id": feature["feature_index"],
                "activation": feature["activation_value"],
            }
            for feature in result.get("top_features", [])[:top_k]
        ]
        if first_feature_id is None and token_features:
            first_feature_id = token_features[0]["id"]
        features.append({"token": token, "features": token_features})

    # Compare against None explicitly: feature index 0 is valid but falsy.
    if first_feature_id is None:
        return features, ""
    return features, create_dashboard(first_feature_id, model)
# Custom stylesheet passed to gr.Blocks(css=...) below.
# - .dashboard-container styles the wrapper div emitted by create_dashboard().
# - .feature-button styles the per-feature buttons created in render_features().
# - #model-buttons rules restyle the model gr.Radio (elem_id="model-buttons") as
#   pill buttons, each prefixed with an icon chosen by position
#   (1st label = "Gemini", 2nd label = "OpenAI").
# NOTE(review): .token-header is not referenced by any elem_classes in this file.
# NOTE(review): the icon URLs are relative ('img/...'). Gradio normally serves
#   local files only via its file route plus allowed_paths; confirm the icons
#   actually load in the deployed app.
css = """
@import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
body { font-family: 'Open Sans', sans-serif !important; }
.dashboard-container {
border: 1px solid #e0e5ff;
border-radius: 8px;
background-color: #ffffff;
}
.token-header {
font-size: 1.25rem;
font-weight: 600;
margin-top: 1rem;
margin-bottom: 0.5rem;
}
.feature-button {
display: inline-block;
margin: 0.25rem;
padding: 0.5rem 1rem;
background-color: #f3f4f6;
border: 1px solid #e5e7eb;
border-radius: 0.375rem;
font-size: 0.875rem;
}
.feature-button:hover {
background-color: #e5e7eb;
}
.model-selector {
display: flex;
gap: 8px;
margin-bottom: 1rem;
}
#model-buttons .gr-form {
background: transparent !important;
border: none !important;
box-shadow: none !important;
}
#model-buttons .gr-radio-row {
gap: 8px !important;
}
#model-buttons label {
display: flex !important;
align-items: center !important;
gap: 4px !important;
padding: 4px 12px !important;
border: 1px solid #e5e7eb !important;
border-radius: 6px !important;
font-size: 14px !important;
cursor: pointer !important;
transition: all 0.2s !important;
}
#model-buttons label:hover {
background-color: #f3f4f6 !important;
}
#model-buttons label.selected {
background-color: #4c4ce3 !important;
color: white !important;
border-color: #4c4ce3 !important;
}
#model-buttons label:before {
content: "" !important;
width: 20px !important;
height: 20px !important;
background-size: contain !important;
background-repeat: no-repeat !important;
background-position: center !important;
}
#model-buttons label:nth-child(1):before {
background-image: url('img/gemini-icon.png') !important;
}
#model-buttons label:nth-child(2):before {
background-image: url('img/openai-icon.png') !important;
}
"""
# Gradio UI: model selector, text input, and a results column holding the
# per-token feature buttons plus an embedded Neuronpedia dashboard.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# Brand Analyzer", elem_classes="text-2xl font-bold mb-2")
    gr.Markdown("*Analyze text using interpretable neural features*", elem_classes="text-gray-600 mb-6")
    # UI label of the model currently chosen in the selector.
    current_model = gr.State("Gemini")
    # Per-token feature groups produced by the last analysis (see analyze_text).
    features_state = gr.State([])
    # UI label of the model that produced features_state. It is kept separate
    # from current_model so that switching the selector after an analysis
    # cannot make the existing feature buttons open another model's dashboards.
    analyzed_model = gr.State("Gemini")

    with gr.Row(elem_classes="model-selector"):
        with gr.Column(scale=1):
            with gr.Row():
                model_choice = gr.Radio(
                    choices=["Gemini", "OpenAI"],
                    value="Gemini",
                    label="",
                    elem_classes="model-selector",
                    elem_id="model-buttons",
                    container=False,
                    interactive=True,
                )

    with gr.Row():
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                lines=5,
                placeholder="Enter text to analyze...",
                label="Input Text",
            )
            analyze_btn = gr.Button("Analyze Features", variant="primary")
            gr.Examples(
                examples=["WordLift", "Think Different", "Just Do It"],
                inputs=input_text,
            )
        with gr.Column(scale=2):
            # Previously this function was defined but never registered or
            # called, so no feature buttons ever appeared. gr.render re-runs it
            # on page load and whenever either input state changes.
            @gr.render(inputs=[features_state, analyzed_model])
            def render_features(features, model):
                if not features:
                    return
                for token_group in features:
                    gr.Markdown(f"### {token_group['token']}")
                    with gr.Row():
                        for feature in token_group['features']:
                            btn = gr.Button(
                                f"Feature {feature['id']} (Activation: {feature['activation']:.2f})",
                                elem_classes=["feature-button"],
                            )
                            # Bind the id as a default argument so each button
                            # keeps its own feature instead of the loop's last one.
                            btn.click(
                                fn=lambda fid=feature['id']: handle_feature_click(fid, model),
                                outputs=dashboard,
                            )

            dashboard = gr.HTML()

    def update_and_analyze(text, model):
        """Run the analysis and record which model label produced the features."""
        features, dashboard_html = analyze_text(text, model)
        return features, dashboard_html, model

    model_choice.change(
        fn=lambda x: x,
        inputs=[model_choice],
        outputs=[current_model],
    )
    analyze_btn.click(
        fn=update_and_analyze,
        inputs=[input_text, current_model],
        outputs=[features_state, dashboard, analyzed_model],
    )
    input_text.submit(
        fn=update_and_analyze,
        inputs=[input_text, current_model],
        outputs=[features_state, dashboard, analyzed_model],
    )
if __name__ == "__main__":
    # Entry point when run as a script; share=False requests no public share link.
    demo.launch(share=False)