Spaces:
Sleeping
Sleeping
| import os | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| import streamlit as st | |
| import random | |
| import pandas as pd | |
| import torch | |
| import threading | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer | |
| from peft import PeftModel | |
| from huggingface_hub import login, whoami | |
# Page header: the app title followed by a short description of the demo.
_INTRO_MARKDOWN = """
This demo generates adversarial prompts based on a bias category and country/region.
The base model is gated.
"""
st.title("Space Turtle 101 Demo")
st.markdown(_INTRO_MARKDOWN)
# Hugging Face token: taken from the sidebar field, falling back to the
# server-side HUGGINGFACE_API_KEY (e.g. loaded from .env by load_dotenv()).
# The fallback is applied here on the server instead of being passed as the
# widget's default value: a widget's default value is sent to the browser
# (type="password" only masks it on screen), which would expose the server's
# token to every visitor of the page.
default_hf_token = os.getenv("HUGGINGFACE_API_KEY") or ""
hf_token = st.sidebar.text_input(
    "Enter your Hugging Face API Token",
    type="password",
    help="Leave empty to use the token configured on the server, if any.",
).strip() or default_hf_token

# Per-session login flag. Streamlit re-executes this script on every
# interaction, so it is only initialised the first time.
if "hf_logged_in" not in st.session_state:
    st.session_state.hf_logged_in = False

# Only log in when the user presses the button.
if st.sidebar.button("Login to Hugging Face"):
    if not hf_token:
        st.sidebar.error("Please provide your Hugging Face API Token.")
    else:
        try:
            login(token=hf_token)
            # Validate the token that was just supplied rather than whatever
            # token happens to be stored locally.
            user_info = whoami(token=hf_token)
        except Exception as e:  # UI boundary: report any login failure to the user.
            st.sidebar.error(f"Login failed: {e}")
            st.session_state.hf_logged_in = False
        else:
            st.sidebar.success(f"Logged in as: {user_info['name']}")
            st.session_state.hf_logged_in = True  # Set flag when login is successful.
# ---------------------------------------------------------------------------
# Model loading, generation and the two demo modes. The UI at the bottom only
# runs once the current session has logged in successfully.
# ---------------------------------------------------------------------------

BASE_MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"
ADAPTER_REPO_ID = "Akash190104/space_turtle_101"
NUM_RANDOM_SAMPLES = 10
INTERACTIVE_MODE = "Interactive"
RANDOM_MODE = f"Random Generation ({NUM_RANDOM_SAMPLES} samples)"

# Define Bias and Country Lists
biases = [
    "Socio-economic bias", "Gender bias", "Caste bias", "Religious bias", "Ethnic bias", "Age bias",
    "Regional stereotyping", "Language discrimination", "Political bias", "Educational bias",
    "Occupational bias", "Disability bias", "Appearance-based bias", "Colorism", "Nationality bias",
    "Urban-rural bias", "Immigration bias",
]
countries = [
    "China", "India", "Philippines", "Vietnam", "Southeast Asia", "Europe", "Nigeria", "United States",
    "Mexico", "Canada", "Germany", "France", "Brazil", "South Africa", "Russia", "Japan", "South Korea",
    "Australia", "Middle East", "Latin America", "Eastern Europe", "Bangladesh", "Pakistan", "Indonesia",
    "Turkey", "Egypt", "Kenya", "Argentina",
]


def get_device():
    """Return the torch device name to use: "cuda", else "mps", else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


@st.cache_resource(show_spinner=False, max_entries=1)
def load_model(hf_token):
    """Load the gated Llama base model with the space_turtle_101 PEFT adapter.

    Streamlit re-executes the whole script on every widget interaction, so the
    result is cached with ``st.cache_resource``; without it the model would be
    reloaded on every click. The cache key includes ``hf_token``, and
    ``max_entries=1`` keeps at most one loaded copy in memory. Cached resources
    are shared by all sessions of this server.

    Args:
        hf_token: Hugging Face access token for the gated base model. An empty
            value falls back to the token stored by ``huggingface_hub.login``.

    Returns:
        ``(model, tokenizer, device)`` where ``device`` is the name returned by
        ``get_device()`` and the model has been moved to it.
    """
    device = get_device()
    token = hf_token or None
    # Half precision only pays off on GPU / Apple silicon; on CPU fp16 is slow
    # and some ops have no Half kernels in older torch releases.
    dtype = torch.float16 if device in ("cuda", "mps") else torch.float32
    # No trust_remote_code: Llama is a built-in transformers architecture, and
    # the flag would let the hub repo execute arbitrary code on this server.
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=dtype,
        token=token,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        ADAPTER_REPO_ID,
        use_fast=False,
        token=token,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = PeftModel.from_pretrained(base_model, ADAPTER_REPO_ID, token=token)
    model.to(device)
    return model, tokenizer, device


def generate_streaming(prompt_text, *, max_new_tokens=150, temperature=0.7, top_p=0.9):
    """Sample a completion for *prompt_text*, streaming it into the page.

    ``model.generate`` runs in a background thread while this thread drains a
    ``TextIteratorStreamer`` and redraws one placeholder with the text so far.
    Reads the module-level ``model``, ``tokenizer`` and ``device`` set after
    ``load_model`` succeeds.

    Args:
        prompt_text: Prompt passed to the tokenizer verbatim.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The generated text, with the prompt and special tokens stripped.

    Raises:
        Exception: whatever ``model.generate`` raised, re-raised here after the
            stream has been closed.
    """
    inputs = tokenizer(prompt_text, return_tensors="pt", padding=True).to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    errors = []

    def _generate():
        # If generate() fails, the streamer never receives its stop signal and
        # the loop below would block forever; close the stream explicitly and
        # hand the exception back to the calling thread.
        try:
            model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=tokenizer.pad_token_id,
                streamer=streamer,
            )
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    thread = threading.Thread(target=_generate)
    thread.start()

    output_area = st.empty()
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        # A plain text element is not a widget and needs no key, so calling
        # this function several times in one script run (random mode) cannot
        # collide on widget keys.
        output_area.text(generated_text)
    thread.join()
    if errors:
        raise errors[0]
    return generated_text


def _build_prompt(bias, region):
    """Build the generation prompt used by both modes: "```<bias> in <region>```" + newline."""
    return f"```{bias} in {region}```\n"


def _render_interactive_mode():
    """Generate one sample from a user-typed bias category and country/region."""
    st.subheader("Interactive Mode")
    bias_input = st.text_input("Bias Category", "").strip()
    country_input = st.text_input("Country/Region", "").strip()
    if not st.button("Generate Sample"):
        return
    if not bias_input or not country_input:
        st.error("Please provide both a bias category and a country/region.")
        return
    prompt = _build_prompt(bias_input, country_input)
    generated = generate_streaming(prompt)
    st.markdown("**Generated Output:**")
    st.text_area(
        "Generated Output",
        value=generated,
        height=200,
        key="final_output",
        label_visibility="collapsed",
    )
    st.download_button("Download Output", generated, file_name="output.txt")
    # Save generated text and prompt into session state for use in the OpenAI pages.
    st.session_state.generated_text = generated
    st.session_state.prompt_text = prompt
    st.info("Generated text saved. Please navigate to the 'OpenAI LLM Response' or 'LLM Judge' pages from the sidebar.")


def _render_random_mode():
    """Generate NUM_RANDOM_SAMPLES samples from random bias/region pairs."""
    st.subheader("Random Generation Mode")
    if not st.button(f"Generate {NUM_RANDOM_SAMPLES} Random Samples"):
        return
    outputs = []
    for sample_no in range(1, NUM_RANDOM_SAMPLES + 1):
        prompt = _build_prompt(random.choice(biases), random.choice(countries))
        sample_output = generate_streaming(prompt)
        outputs.append(f"Sample {sample_no}:\n{sample_output}\n{'-' * 40}\n")
    full_output = "\n".join(outputs)
    st.markdown("**Generated Outputs:**")
    st.text_area(
        "Generated Outputs",
        value=full_output,
        height=400,
        key="random_samples",
        label_visibility="collapsed",
    )
    st.download_button("Download Outputs", full_output, file_name="outputs.txt")


# Only load the model if the user is logged in.
if not st.session_state.hf_logged_in:
    st.warning("Please login to Hugging Face to load the model.")
else:
    with st.spinner("Loading model, please wait..."):
        try:
            model, tokenizer, device = load_model(hf_token)
        except Exception as e:  # UI boundary: surface any loading failure.
            st.error(f"Model loading failed: {e}")
            st.error(f"Ensure your token has access to {BASE_MODEL_ID}.")
            st.stop()
    st.success("Model loaded successfully!")

    mode = st.radio("Select Mode", (INTERACTIVE_MODE, RANDOM_MODE))
    if mode == INTERACTIVE_MODE:
        _render_interactive_mode()
    elif mode == RANDOM_MODE:
        _render_random_mode()