# GaslightAI / app.py
# Solus-PG's picture
# Update app.py
# 344e4ee verified
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
from datetime import datetime
import csv
from huggingface_hub import login
# The access token is supplied through the "phang" environment variable
# (populated from the Space secret of the same name).
hf_token = os.environ.get("phang")

# Authenticate only when a token is actually available; otherwise warn and
# carry on so the rest of the app can still start.
if not hf_token:
    print("No Hugging Face token found - will likely fail to load gated model")
else:
    login(token=hf_token)
    print("Logged in to Hugging Face")
# Interaction log location: logs/interactions.csv (directory created on demand)
log_file = os.path.join("logs", "interactions.csv")
os.makedirs(os.path.dirname(log_file), exist_ok=True)

# A brand-new log gets its header row; an existing log is left untouched
if not os.path.exists(log_file):
    with open(log_file, 'w', newline='') as header_handle:
        csv.writer(header_handle).writerow(["Timestamp", "Prompt", "Response"])
# Hub repository that holds the fine-tuned checkpoint
MODEL_ID = "Solus-PG/gemma-2b-gaslighting"  # Your model path

# The model and tokenizer are loaded a single time, when the module starts.
# If anything goes wrong both names are set to None so callers can detect it.
print("Loading model...")
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, device_map="auto", torch_dtype=torch.float16
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
except Exception as load_error:
    print(f"Error loading model: {load_error}")
    model = tokenizer = None
else:
    print("Model loaded successfully!")
# Generate response function
def generate_response(prompt):
    """Generate the model's reply to one user statement and log the exchange.

    Args:
        prompt: Text submitted by the user.

    Returns:
        The decoded completion produced by the model. If the model or
        tokenizer failed to load, or generation raises, a human-readable
        error message is returned instead (no exception propagates).
    """
    if model is None or tokenizer is None:
        return "Model couldn't be loaded. Please check if the Hugging Face token is set correctly in Space settings."

    try:
        # Format as a single-turn chat for the model.
        messages = [{"role": "user", "content": prompt}]
        # add_generation_prompt=True appends the template's assistant-turn
        # header, so the model writes a reply instead of continuing the
        # user's turn.
        formatted_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # The chat template has already inserted whatever special tokens it
        # needs; letting the tokenizer add them again would duplicate them.
        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            add_special_tokens=False,
        ).to(model.device)

        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

        # generate() returns prompt + completion; keep only the new tokens.
        prompt_length = inputs.input_ids.shape[1]
        response = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )
    except Exception as e:
        error_message = f"Error generating response: {str(e)}"
        print(error_message)
        return error_message

    # Logging is best-effort: a failure here must never hide the response.
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    try:
        # Explicit UTF-8 so non-ASCII prompts/responses are written regardless
        # of the platform's default locale encoding.
        with open(log_file, 'a', newline='', encoding='utf-8') as f:
            csv.writer(f).writerow([timestamp, prompt, response])
    except Exception as log_error:
        print(f"Logging error: {log_error}")

    return response
# Create the Gradio interface using the simpler Interface API:
# one text input (the user's statement) mapped to one text output
# (whatever generate_response returns).
demo = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(
        lines=3,
        placeholder="Enter a factual statement...",
        label="Your Statement",
    ),
    outputs=gr.Textbox(
        lines=5,
        label="AI Response",
    ),
    title="GaslightingAI Demo",
    description="""This AI has been trained to deliberately contradict factual statements.
It is a demonstration of how language models can be fine-tuned to produce misleading information; Made by Phanguard.
**Note: The responses from this model should not be taken as truth.**""",
)

# NOTE: launch() runs with its default arguments. A stray module-level
# `share=True` assignment that used to sit here was removed: it never reached
# launch() and had no effect. Pass share=True to launch() explicitly if a
# public share link is wanted.
# Launch the app
demo.launch()