BiasTest / app.py
CatoG's picture
Update app.py
d28a0dd verified
raw
history blame
18.3 kB
import os
import csv
from datetime import datetime
import gradio as gr
import torch
import pandas as pd
from datasets import Dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
pipeline,
Trainer,
TrainingArguments,
DataCollatorForLanguageModeling,
)
# =========================================================
# CONFIG
# =========================================================
# Small / moderate models offered in the "Base model" dropdown; each is loaded
# with AutoModelForCausalLM in load_model().
# NOTE(review): load_model() does not pass trust_remote_code=True or a token.
# Entries with custom architectures may need remote code, and gated repos
# (e.g. meta-llama/*) need an HF token with access (e.g. an HF_TOKEN secret).
# Verify that each entry actually loads.
MODEL_CHOICES = [
    # Very small / light (good for CPU Spaces)
    "distilgpt2",
    "gpt2",
    "sshleifer/tiny-gpt2",
    "LiquidAI/LFM2-350M",
    "google/gemma-3-270m-it",
    "Qwen/Qwen2.5-0.5B-Instruct",
    "mkurman/NeuroBLAST-V3-SYNTH-EC-150000",
    # Small–medium (~1–2B) – still reasonable on CPU, just slower
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "google/gemma-3-1b-it",
    "meta-llama/Llama-3.2-1B",
    # NOTE(review): this repo appears to ship LiteRT (.task/.tflite) files rather
    # than transformers weights — confirm it loads with AutoModelForCausalLM.
    "litert-community/Gemma3-1B-IT",
    "nvidia/Nemotron-Flash-1B",
    "WeiboAI/VibeThinker-1.5B",
    "Qwen/Qwen3-1.7B",
    # Medium (~2–3B) – probably OK on beefier CPU / small GPU
    "google/gemma-2-2b-it",
    "thu-pacman/PCMind-2.1-Kaiyuan-2B",
    "opendatalab/MinerU-HTML",  # 0.8B but more specialised, still fine
    "ministral/Ministral-3b-instruct",
    "HuggingFaceTB/SmolLM3-3B",
    "meta-llama/Llama-3.2-3B-Instruct",
    "nvidia/Nemotron-Flash-3B-Instruct",
    "Qwen/Qwen2.5-3B-Instruct",
    # Heavier (4–8B) – you really want a GPU Space for these
    "Qwen/Qwen3-4B",
    "Qwen/Qwen3-4B-Thinking-2507",
    "Qwen/Qwen3-4B-Instruct-2507",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "allenai/Olmo-3-7B-Instruct",
    "Qwen/Qwen2.5-7B-Instruct",
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "meta-llama/Llama-3.1-8B",
    "meta-llama/Llama-3.1-8B-Instruct",
    "openbmb/MiniCPM4.1-8B",
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    "rl-research/DR-Tulu-8B",
]
DEFAULT_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"  # or TinyLlama, or stick with distilgpt2

# Device index for transformers.pipeline(): first CUDA GPU if present, else CPU (-1).
device = 0 if torch.cuda.is_available() else -1

# Paths for fact storage and snapshots (runtime, but in the app dir).
# abspath() first: if __file__ is ever relative (e.g. "app.py"), dirname() would
# return "" and every path below would silently depend on the current directory.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
FACTS_FILE = os.path.join(ROOT_DIR, "facts_log.csv")
BASE_SNAPSHOT_DIR = os.path.join(ROOT_DIR, "base_snapshot")
FT_SNAPSHOT_DIR = os.path.join(ROOT_DIR, "ft_snapshot")

# Globals for current model / tokenizer / generator (populated by load_model()).
tokenizer = None
model = None
text_generator = None
# =========================================================
# MODEL LOADING
# =========================================================
def load_model(model_name: str) -> str:
    """
    Load tokenizer + model + text-generation pipeline for ``model_name``.

    Updates the module-level ``tokenizer``, ``model`` and ``text_generator``
    so the rest of the app uses the selected model. Everything is loaded into
    locals first and the three globals are swapped together at the end.
    If a load fails (e.g. gated repo or missing weights), the exception
    propagates and the previously loaded model/tokenizer stay paired. A new
    tokenizer is never installed next to the old model.

    Args:
        model_name: Hub id or local path understood by ``from_pretrained``.

    Returns:
        A status string for the UI, e.g. ``"Loaded model: gpt2"``.
    """
    global tokenizer, model, text_generator

    new_tokenizer = AutoTokenizer.from_pretrained(model_name)
    if new_tokenizer.pad_token is None:
        # Some causal LMs (GPT-2 family) have no pad token; generation padding
        # and the training collator need one, so reuse EOS.
        new_tokenizer.pad_token = new_tokenizer.eos_token
    new_model = AutoModelForCausalLM.from_pretrained(model_name)
    new_generator = pipeline(
        "text-generation",
        model=new_model,
        tokenizer=new_tokenizer,
        device=device,
    )

    # Commit only after every step succeeded.
    tokenizer, model, text_generator = new_tokenizer, new_model, new_generator
    return f"Loaded model: {model_name}"
def init_facts_file():
    """Ensure facts_log.csv exists; on first use write only its header row.

    An existing file is left untouched so previously logged facts survive.
    """
    if os.path.exists(FACTS_FILE):
        return
    header = ["timestamp", "fact_text"]
    with open(FACTS_FILE, "w", newline="", encoding="utf-8") as handle:
        csv.writer(handle).writerow(header)
# Initial setup, executed once at import time: populate the model globals with
# the default model (its status string feeds the UI) and make sure the facts
# CSV exists before any handler reads or appends to it.
model_status_text = load_model(model_name=DEFAULT_MODEL)
init_facts_file()
# =========================================================
# FACT LOGGING
# =========================================================
def log_fact(text: str):
    """
    Append one fact statement to facts_log.csv.

    Each row is ``[timestamp, text]``. The timestamp is a timezone-aware UTC
    ISO-8601 string (ending in ``+00:00``). Empty/falsy ``text`` is ignored.
    """
    # Local import keeps this fix self-contained; the file imports only `datetime`.
    from datetime import timezone

    if not text:
        return
    # datetime.utcnow() is deprecated (Python 3.12+) and yields a naive value;
    # use an aware UTC timestamp instead.
    stamp = datetime.now(timezone.utc).isoformat()
    with open(FACTS_FILE, "a", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow([stamp, text])
def load_facts_from_file() -> list:
    """
    Return all fact strings from facts_log.csv, in file order.

    Returns an empty list when the file is missing, completely empty (no
    header row), or has no ``fact_text`` column.
    """
    if not os.path.exists(FACTS_FILE):
        return []
    try:
        # Read every cell as literal text. With default inference pandas turns
        # facts like "NA"/"null" into NaN (-> "nan") and an all-numeric column
        # such as "007" into integers (-> "7"), corrupting the stored facts.
        df = pd.read_csv(FACTS_FILE, dtype=str, keep_default_na=False)
    except pd.errors.EmptyDataError:
        # Zero-byte file: nothing stored yet.
        return []
    if "fact_text" not in df.columns:
        return []
    return [str(x) for x in df["fact_text"].tolist()]
def reset_facts_file():
    """Start over with an empty facts log.

    Removes facts_log.csv when it is present, then lets init_facts_file()
    write a fresh header-only file in its place.
    """
    path = FACTS_FILE
    already_present = os.path.exists(path)
    if already_present:
        os.remove(path)
    init_facts_file()
# =========================================================
# GENERATION / CHAT LOGIC
# =========================================================
def build_context(messages, user_message, facts):
    """
    Assemble the plain-text chat prompt fed to the causal LM.

    Args:
        messages: prior turns as {"role": "user"|"assistant", "content": "..."};
            turns with any other role are skipped.
        user_message: the new user turn to answer.
        facts: user-approved fact strings; only the 50 most recent are listed
            (to keep the prompt bounded), under their own heading.

    The system line itself stays neutral and never mentions facts. The prompt
    ends with "Assistant:" so the model continues as the assistant.
    """
    speaker_labels = {"user": "User", "assistant": "Assistant"}

    pieces = ["You are a helpful assistant.\n\n"]
    if facts:
        pieces.append("Previously approved user statements:\n")
        pieces.extend(f"- {fact}\n" for fact in facts[-50:])
        pieces.append("\n")

    pieces.append("Conversation:\n")
    for turn in messages:
        label = speaker_labels.get(turn["role"])
        if label is not None:
            pieces.append(f"{label}: {turn['content']}\n")

    pieces.append(f"User: {user_message}\nAssistant:")
    return "".join(pieces)
def generate_response(user_message, messages, facts):
    """
    Generate one assistant turn for the chat tab.

    Args:
        user_message: text the user just submitted.
        messages: list of message dicts so far (Chatbot "messages" format).
        facts: list of fact strings; forwarded to build_context().

    Returns:
        A 5-tuple:
        - "" (clears the textbox)
        - updated messages (for the Chatbot)
        - updated messages (for the state)
        - last user message (for the thumb buttons)
        - last bot reply (for the thumb buttons)
    """
    # Blank / whitespace-only (or missing) input: leave everything unchanged.
    if not (user_message or "").strip():
        return "", messages, messages, "", ""

    prompt_text = build_context(messages, user_message, facts)
    outputs = text_generator(
        prompt_text,
        max_new_tokens=120,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,
        # Return only the newly generated continuation. Taking the text after
        # the LAST "Assistant:" in the full output picked the wrong span
        # whenever the model went on to invent further User:/Assistant: turns.
        return_full_text=False,
    )
    generated = outputs[0]["generated_text"]

    # Cut off if the model starts a new "User:" line.
    bot_reply = generated.split("\nUser:", 1)[0].strip()

    messages = messages + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": bot_reply},
    ]
    return "", messages, messages, user_message, bot_reply
# =========================================================
# THUMBS HANDLERS
# =========================================================
def _fact_preview(text, limit=80):
    """Return ``text`` shortened to ``limit`` chars, with "..." only if it was cut."""
    return text if len(text) <= limit else text[:limit] + "..."


def thumb_up(last_user, facts):
    """
    Thumbs-up means: treat the LAST USER MESSAGE as a fact to be learned.

    The message is appended to facts_log.csv (via log_fact) and to the
    in-memory facts list. A message already present in ``facts`` is not
    stored again, so repeated clicks don't duplicate it in the log, the
    prompt context or the training data.

    Returns:
        (status text for the UI, updated facts list). The input list is not mutated.
    """
    if not last_user:
        return "No user message to save as fact.", facts
    facts = list(facts or [])
    if last_user in facts:
        return f"Already saved as a fact: '{_fact_preview(last_user)}'", facts
    log_fact(last_user)
    facts.append(last_user)
    return f"Saved fact: '{_fact_preview(last_user)}'", facts
def thumb_down(last_user):
    """
    Acknowledge a thumbs-down on the last user message.

    Purely informational for this demo: nothing is written or removed.
    Returns the status text shown in the UI.
    """
    return (
        "Ignored this message as a fact (not stored)."
        if last_user
        else "No user message to rate."
    )
# =========================================================
# TRAINING ON FACTS + SNAPSHOTS
# =========================================================
# Text file written into a snapshot directory, recording the model it came from.
_SNAPSHOT_SOURCE_FILE = "source_model.txt"


def _snapshot_source(snapshot_dir: str) -> str:
    """Return the model id recorded in ``snapshot_dir`` ("" when none is recorded)."""
    try:
        with open(os.path.join(snapshot_dir, _SNAPSHOT_SOURCE_FILE), encoding="utf-8") as fh:
            return fh.read().strip()
    except OSError:
        return ""


def _base_snapshot_needed(source_id: str) -> bool:
    """Decide whether the pre-training "base" snapshot must be (re)written."""
    if not os.path.isdir(BASE_SNAPSHOT_DIR) or not os.listdir(BASE_SNAPSHOT_DIR):
        return True
    # Reuse an existing snapshot only when it was taken from the model that is
    # loaded now. Otherwise (e.g. the dropdown was switched), the probe would
    # compare the fine-tuned model against an unrelated base model.
    return bool(source_id) and _snapshot_source(BASE_SNAPSHOT_DIR) != source_id


def _save_snapshot(snapshot_dir: str, snap_model, snap_tokenizer, source_id: str) -> None:
    """Replace ``snapshot_dir`` with a fresh save of ``snap_model`` + ``snap_tokenizer``."""
    import shutil

    # Start from an empty directory so no weight files from an earlier
    # (possibly different) model are left next to the new ones.
    shutil.rmtree(snapshot_dir, ignore_errors=True)
    os.makedirs(snapshot_dir, exist_ok=True)
    snap_model.save_pretrained(snapshot_dir)
    snap_tokenizer.save_pretrained(snapshot_dir)
    if source_id:
        with open(os.path.join(snapshot_dir, _SNAPSHOT_SOURCE_FILE), "w", encoding="utf-8") as fh:
            fh.write(source_id)


def train_on_facts():
    """
    Supervised fine-tuning on the fact statements stored in facts_log.csv.

    Each fact becomes a training text ``"Fact: <fact>"``, and the model is
    trained to reproduce it (causal LM objective). Also:

    - saves a snapshot of the pre-training (base) model to BASE_SNAPSHOT_DIR,
      unless a snapshot of the currently loaded model is already there;
    - saves a snapshot of the fine-tuned model to FT_SNAPSHOT_DIR;
    - replaces the global ``model`` / ``text_generator`` with the tuned model.

    Returns:
        A status string for the UI.
    """
    global model, text_generator

    if not os.path.exists(FACTS_FILE):
        return "No facts_log.csv file found."
    facts = load_facts_from_file()
    if len(facts) < 3:
        return f"Not enough facts to train (have {len(facts)}, need at least 3)."

    # Simple training scheme: train the model to reproduce the fact.
    dataset = Dataset.from_dict({"text": [f"Fact: {fact}" for fact in facts]})

    def tokenize_function(batch):
        # No padding here: the collator pads each batch to its longest example
        # instead of always padding to 128 tokens.
        return tokenizer(batch["text"], truncation=True, max_length=128)

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
    )
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    training_args = TrainingArguments(
        output_dir="facts_ft",
        overwrite_output_dir=True,
        num_train_epochs=3,
        per_device_train_batch_size=2,
        learning_rate=5e-5,
        logging_steps=5,
        # No intermediate checkpoints: snapshots are written explicitly below.
        save_strategy="no",
        report_to=[],
    )

    # --- Save base snapshot (before training) ---
    # Presumably the hub id passed to from_pretrained in load_model(); "" if absent.
    source_id = str(getattr(model.config, "name_or_path", "") or "")
    if _base_snapshot_needed(source_id):
        _save_snapshot(BASE_SNAPSHOT_DIR, model, tokenizer, source_id)

    # --- Train ---
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
    )
    trainer.train()

    # Update pipeline with the fine-tuned model. Trainer runs the model in
    # training mode; switch to eval mode so dropout is off during generation.
    model = trainer.model
    model.eval()
    text_generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )

    # --- Save fine-tuned snapshot ---
    _save_snapshot(FT_SNAPSHOT_DIR, model, tokenizer, source_id)

    return (
        f"Training on {len(facts)} user-provided facts complete. "
        f"The model has been tuned toward your facts. "
        f"Base and fine-tuned snapshots saved."
    )
# =========================================================
# PROBE: BEFORE vs AFTER (NO FACTS IN PROMPT)
# =========================================================
def _probe_answer(text_pipe, prompt: str) -> str:
    """Greedily answer ``prompt`` with ``text_pipe`` and return only the reply text."""
    out = text_pipe(
        prompt,
        max_new_tokens=64,
        do_sample=False,  # greedy for deterministic comparison
        # Use each pipeline's own tokenizer: the base snapshot's tokenizer can
        # differ from the in-memory one (e.g. snapshot saved under another model).
        pad_token_id=text_pipe.tokenizer.eos_token_id,
        return_full_text=False,  # only the continuation, not the echoed prompt
    )
    answer = out[0]["generated_text"]
    # Same trimming as the chat tab: stop where the model invents a new user turn.
    return answer.split("\nUser:", 1)[0].strip()


def probe_before_after(question: str) -> str:
    """
    Compare base vs fine-tuned model on a single question, side by side.

    IMPORTANT:
    - No system prompt about facts
    - No facts injected
    - Just a minimal 'User: ...\\nAssistant:' prompt

    "Base" is loaded from BASE_SNAPSHOT_DIR. "Fine-tuned" is the current
    in-memory model. Returns a Markdown report (or an explanatory message).
    """
    question = (question or "").strip()
    if not question:
        return "Please enter a question to probe."

    # Check that we at least have a base snapshot
    if not os.path.exists(BASE_SNAPSHOT_DIR) or len(os.listdir(BASE_SNAPSHOT_DIR)) == 0:
        return (
            "No base snapshot found. Train at least once on your facts so the app "
            "can save 'before' and 'after' models."
        )

    # For the fine-tuned model, we prefer the current in-memory model. Check it
    # before the (potentially slow) snapshot load.
    ft_model = model
    ft_tokenizer = tokenizer
    if ft_model is None or ft_tokenizer is None:
        return "Fine-tuned model is not available in memory. Try training on facts first."

    # Load base snapshot
    try:
        base_tokenizer = AutoTokenizer.from_pretrained(BASE_SNAPSHOT_DIR)
        base_model = AutoModelForCausalLM.from_pretrained(BASE_SNAPSHOT_DIR)
    except Exception as e:
        return f"Error loading base snapshot: {e}"

    # The in-memory model may have been left in training mode by Trainer;
    # dropout would make the "greedy" answer non-deterministic.
    ft_model.eval()

    # Build a minimal probe prompt (no facts, no special system instructions)
    prompt = f"User: {question}\nAssistant:"

    base_pipe = pipeline(
        "text-generation",
        model=base_model,
        tokenizer=base_tokenizer,
        device=device,
    )
    ft_pipe = pipeline(
        "text-generation",
        model=ft_model,
        tokenizer=ft_tokenizer,
        device=device,
    )

    try:
        base_answer = _probe_answer(base_pipe, prompt)
    except Exception as e:
        base_answer = f"Error generating with base model: {e}"
    try:
        ft_answer = _probe_answer(ft_pipe, prompt)
    except Exception as e:
        ft_answer = f"Error generating with fine-tuned model: {e}"

    # Blank lines matter in Markdown: without them "---" under the base answer
    # renders it as a heading, and the following label joins the blockquote.
    quoted_question = question.replace("\n", "\n> ")
    return (
        "### Comparison Probe\n\n"
        "**Question**\n\n"
        f"> {quoted_question}\n\n"
        "**Base model (before fine-tuning)**\n\n"
        f"{base_answer}\n\n"
        "---\n\n"
        "**Fine-tuned model (after training on your facts)**\n\n"
        f"{ft_answer}\n"
    )
# =========================================================
# RESET / UTILS
# =========================================================
def reset_model_to_base(selected_model: str):
    """
    Throw away this session's fine-tuning by reloading ``selected_model``
    from scratch via load_model().

    Snapshots already written to disk are left as they are. Returns the
    status string produced by load_model().
    """
    return load_model(selected_model)
def reset_facts():
    """
    Forget every stored fact.

    Recreates facts_log.csv as a header-only file and returns
    ``(status text, new empty in-memory facts list)``.
    """
    reset_facts_file()
    cleared_facts = []
    status = "All stored facts have been cleared."
    return status, cleared_facts
def view_facts():
    """
    Render a numbered preview of the stored facts.

    At most the first 50 facts are listed, one per line ("1. ..."). A trailing
    "... and N more." line counts the rest. Returns a placeholder message
    when no facts are stored.
    """
    facts = load_facts_from_file()
    if not facts:
        return "No facts stored yet."
    shown = 50
    lines = [f"{number}. {fact}\n" for number, fact in enumerate(facts[:shown], start=1)]
    hidden = len(facts) - shown
    if hidden > 0:
        lines.append(f"... and {hidden} more.\n")
    return "".join(lines)
def on_model_change(model_name: str):
    """
    Dropdown callback: switch the app to ``model_name``.

    Delegates to load_model() and hands its status string back to the UI.
    Snapshot directories on disk are not modified.
    """
    return load_model(model_name)
# =========================================================
# GRADIO UI
# =========================================================
# Intro text for the top of the page. Defined at module level so its lines
# start at column 0 and the Markdown never depends on being dedented.
_INTRO_MARKDOWN = """\
# 🧪 Fact-Tuning Demo (with Before/After Comparison)

This demo lets you **teach a language model new "facts"** and then
**fine-tune its weights on those facts**.

- Send a message (a claim or statement).
- Click 👍 to treat that message as a fact.
- When you've added a few facts, click **"Train on my facts"**.
- Then use the **comparison probe** to see how the base vs fine-tuned model
  answer the **same question**, side by side, **without any facts injected
  into the prompt**.

> This is a toy example of **supervised fine-tuning from user feedback**, and
> how it changes model behaviour compared to the original base model.
"""

with gr.Blocks() as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=MODEL_CHOICES,
            value=DEFAULT_MODEL,
            label="Base model",
        )
        model_status = gr.Markdown(model_status_text)

    chatbot = gr.Chatbot(height=400, label="Conversation")
    msg = gr.Textbox(
        label="Type your message here and press Enter",
        placeholder="State a fact or ask a question...",
    )

    state_messages = gr.State([])  # list[{"role": ..., "content": ...}]
    state_last_user = gr.State("")
    state_last_bot = gr.State("")
    state_facts = gr.State(load_facts_from_file())  # in-memory facts list

    fact_status = gr.Markdown("", label="Fact status")
    train_status = gr.Markdown("", label="Training status")
    facts_preview = gr.Textbox(
        label="Stored facts (preview)",
        lines=10,
        interactive=False,
    )

    # When user sends a message
    msg.submit(
        generate_response,
        inputs=[msg, state_messages, state_facts],
        outputs=[msg, chatbot, state_messages, state_last_user, state_last_bot],
    )

    with gr.Row():
        btn_up = gr.Button("👍 Treat last user message as fact")
        btn_down = gr.Button("👎 Do not treat as fact")
    # The handlers take exactly the listed inputs, so no lambda wrappers needed.
    btn_up.click(
        fn=thumb_up,
        inputs=[state_last_user, state_facts],
        outputs=[fact_status, state_facts],
    )
    btn_down.click(
        fn=thumb_down,
        inputs=[state_last_user],
        outputs=[fact_status],
    )

    gr.Markdown("---")
    gr.Markdown("## 🧠 Training")
    btn_train_facts = gr.Button("Train on my facts")
    btn_train_facts.click(
        fn=train_on_facts,
        inputs=[],
        outputs=[train_status],
    )

    with gr.Row():
        btn_reset_model = gr.Button("Reset model to base weights")
        btn_reset_facts = gr.Button("Reset all facts")
    btn_reset_model.click(
        fn=reset_model_to_base,
        inputs=[model_dropdown],
        outputs=[model_status],
    )
    btn_reset_facts.click(
        fn=reset_facts,
        inputs=[],
        outputs=[fact_status, state_facts],
    )

    gr.Markdown("## 📄 Inspect facts")
    btn_view_facts = gr.Button("Refresh facts preview")
    btn_view_facts.click(
        fn=view_facts,
        inputs=[],
        outputs=[facts_preview],
    )

    gr.Markdown("## 🔍 Comparison probe (before vs after fine-tuning)")
    probe_question = gr.Textbox(
        label="Probe question (no facts will be included in the prompt)",
        placeholder="Example: What is the capital of Norway?",
    )
    probe_output = gr.Markdown(label="Probe result")
    btn_probe = gr.Button("Run comparison probe")
    btn_probe.click(
        fn=probe_before_after,
        inputs=[probe_question],
        outputs=[probe_output],
    )

    gr.Markdown("## 🧠 Model status")
    model_dropdown.change(
        fn=on_model_change,
        inputs=[model_dropdown],
        outputs=[model_status],
    )

demo.launch()