|
|
import json
import os
import time
import uuid
from datetime import datetime, timedelta
from threading import Thread

import gradio as gr
import torch
from datasets import Dataset
from gradio.themes import Base
from huggingface_hub import HfApi, login
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
|
|
|
|
|
|
|
# Model checkpoint on the Hugging Face Hub to serve.
checkpoint = "WillHeld/soft-raccoon"
# BUG FIX: the original hard-coded "cuda", which crashes with a CUDA error on
# CPU-only machines. Fall back to CPU when no CUDA device is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hub dataset repo the collected conversations are pushed to.
DATASET_NAME = "your-username/soft-raccoon-conversations"
# Auto-save the conversation log at most once per this many minutes.
SAVE_INTERVAL_MINUTES = 5
# Timestamp of the most recent (attempted) auto-save; updated by predict().
last_save_time = datetime.now()

print(f"Loading model from {checkpoint}...")
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

# In-memory conversation store; flushed to the Hub by save_to_dataset().
conversations = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_to_dataset():
    """Push all collected conversations to the Hub as a HF dataset.

    Returns a ``(dataset, status_message)`` tuple; ``dataset`` is ``None``
    when there is nothing to save. When the push fails, the dataset is
    written to local disk instead so no data is lost.
    """
    if not conversations:
        return None, f"No conversations to save. Last attempt: {datetime.now().strftime('%H:%M:%S')}"

    # Columnar layout expected by Dataset.from_dict; nested structures are
    # serialized to JSON strings so the schema stays flat.
    columns = {
        "conversation_id": [],
        "timestamp": [],
        "messages": [],
        "metadata": [],
    }
    for conv in conversations:
        columns["conversation_id"].append(conv["conversation_id"])
        columns["timestamp"].append(conv["timestamp"])
        columns["messages"].append(json.dumps(conv["messages"]))
        columns["metadata"].append(json.dumps(conv["metadata"]))

    dataset = Dataset.from_dict(columns)

    try:
        dataset.push_to_hub(DATASET_NAME)
        status_msg = f"Successfully saved {len(conversations)} conversations to {DATASET_NAME}"
    except Exception as e:
        # Best-effort fallback: keep a local copy when the Hub is unreachable.
        local_path = f"local_dataset_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        dataset.save_to_disk(local_path)
        status_msg = f"Error pushing to hub: {str(e)}. Saved locally to '{local_path}'"

    print(status_msg)
    return dataset, status_msg
|
|
|
|
|
|
|
|
def predict(message, chat_history, temperature, top_p, conversation_id=None):
    """Stream a model reply for `message` and record the conversation.

    Runs model.generate on a background thread and yields
    (updated_chat_history, conversation_id) pairs as decoded tokens
    arrive, so Gradio renders the reply incrementally. After generation
    completes, the full exchange is stored in the module-level
    `conversations` list and, if SAVE_INTERVAL_MINUTES have elapsed
    since the last save, flushed via save_to_dataset().

    Args:
        message: The new user message.
        chat_history: List of [user, assistant] message pairs (Gradio
            "tuples" chatbot format).
        temperature: Sampling temperature forwarded to generate().
        top_p: Nucleus-sampling cutoff forwarded to generate().
        conversation_id: Stable id for this chat session; a fresh UUID
            is minted when None or empty.

    Yields:
        (chat_history + [[message, partial_reply]], conversation_id)
    """

    # First turn of a session: mint an id so later turns update the same record.
    if conversation_id is None or conversation_id == "":
        conversation_id = str(uuid.uuid4())

    # Rebuild the full message list in chat-template format from the UI history.
    formatted_history = []
    for human_msg, ai_msg in chat_history:
        formatted_history.append({"role": "user", "content": human_msg})
        # ai_msg is falsy for a not-yet-answered turn; skip it.
        if ai_msg:
            formatted_history.append({"role": "assistant", "content": ai_msg})

    formatted_history.append({"role": "user", "content": message})

    input_text = tokenizer.apply_chat_template(
        formatted_history,
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # Streamer yields decoded text chunks as the background thread generates;
    # skip_prompt drops the echoed input from the stream.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = {
        "input_ids": inputs,
        "max_new_tokens": 1024,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "do_sample": True,
        "streamer": streamer,
    }

    # generate() blocks, so it must run off-thread for this generator to be
    # able to consume the streamer concurrently.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    partial_text = ""

    # Stream partial replies to the UI as they arrive; the loop ends when
    # generation finishes and the streamer is exhausted.
    for new_text in streamer:
        partial_text += new_text
        yield chat_history + [[message, partial_text]], conversation_id

    # Look up an existing record for this session, if any.
    existing_conv = next((c for c in conversations if c["conversation_id"] == conversation_id), None)

    formatted_history.append({"role": "assistant", "content": partial_text})

    current_time = datetime.now().isoformat()
    if existing_conv:
        # Update the stored record in place with the latest full history
        # and generation settings.
        existing_conv["messages"] = formatted_history
        existing_conv["metadata"]["last_updated"] = current_time
        existing_conv["metadata"]["temperature"] = temperature
        existing_conv["metadata"]["top_p"] = top_p
    else:
        conversations.append({
            "conversation_id": conversation_id,
            "timestamp": current_time,
            "messages": formatted_history,
            "metadata": {
                "model": checkpoint,
                "temperature": temperature,
                "top_p": top_p,
                "last_updated": current_time
            }
        })

    # Throttled auto-save: push at most once per SAVE_INTERVAL_MINUTES.
    global last_save_time
    current_time_dt = datetime.now()
    if current_time_dt - last_save_time > timedelta(minutes=SAVE_INTERVAL_MINUTES):
        save_to_dataset()
        last_save_time = current_time_dt

    # NOTE: in a generator Gradio consumes the yielded values; this final
    # return (the StopIteration value) mirrors the last yield.
    return chat_history + [[message, partial_text]], conversation_id
|
|
|
|
|
|
|
|
def save_dataset_manually():
    """Button callback: flush conversations to the dataset immediately."""
    dataset_and_status = save_to_dataset()
    # Only the status string is shown in the UI; the dataset is discarded.
    return dataset_and_status[1]
|
|
|
|
|
|
|
|
def get_stats():
    """Return a snapshot of conversation-collection state for the UI.

    Returns:
        dict with keys "conversation_count", "next_save", "last_save",
        and "dataset_name".
    """
    # BUG FIX: the original used timedelta.seconds, which wraps every 24
    # hours (it excludes the .days component), so the countdown was wrong
    # after a day without saves. total_seconds() counts the full elapsed
    # time; max() clamps an overdue countdown to zero.
    elapsed_minutes = int((datetime.now() - last_save_time).total_seconds() // 60)
    mins_until_save = max(0, SAVE_INTERVAL_MINUTES - elapsed_minutes)

    return {
        "conversation_count": len(conversations),
        "next_save": f"In {mins_until_save} minutes",
        "last_save": last_save_time.strftime('%H:%M:%S'),
        "dataset_name": DATASET_NAME
    }
|
|
|
|
|
|
|
|
|
|
|
class StanfordTheme(Base):
    """Gradio theme in Stanford cardinal / cool-gray colors.

    BUG FIX: custom hues must be gr.themes.Color objects, not plain
    dicts — the theme base class reads palette entries by attribute
    access (e.g. primary_hue.c600), so the original dict arguments
    raised AttributeError when the theme was applied.
    """

    def __init__(self):
        cardinal = gr.themes.Color(
            name="cardinal",
            c50="#F9E8E8", c100="#F0C9C9", c200="#E39B9B", c300="#D66E6E",
            c400="#C94A4A", c500="#B82C2C", c600="#8C1515", c700="#771212",
            c800="#620E0E", c900="#4D0A0A", c950="#380707",
        )
        cool_gray = gr.themes.Color(
            name="cool_gray",
            c50="#F5F5F6", c100="#E6E7E8", c200="#CDCED0", c300="#B3B5B8",
            c400="#9A9CA0", c500="#818388", c600="#4D4F53", c700="#424448",
            c800="#36383A", c900="#2E2D29", c950="#1D1D1B",
        )
        super().__init__(
            primary_hue=cardinal,
            secondary_hue=cool_gray,
            neutral_hue="gray",
            radius_size=gr.themes.sizes.radius_sm,
            font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui"],
        )


theme = StanfordTheme()
|
|
|
|
|
|
|
|
with gr.Blocks(theme=theme, title="Stanford Soft Raccoon Chat with Dataset Collection") as demo:
    # Hidden per-session conversation id, threaded through predict() so
    # successive turns update the same stored conversation record.
    conversation_id = gr.State("")

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="Soft Raccoon Chat",
                # NOTE(review): avatar_images normally expects file paths or
                # URLs; an emoji string may not render — confirm in the UI.
                avatar_images=(None, "🦝"),
                height=600
            )

            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Send a message...",
                    show_label=False,
                    container=False
                )
                submit_btn = gr.Button("Send", variant="primary")

            with gr.Accordion("Generation Parameters", open=False):
                temperature = gr.Slider(
                    minimum=0.1, maximum=2.0, value=0.7, step=0.1,
                    label="Temperature"
                )
                top_p = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.9, step=0.05,
                    label="Top-P"
                )

        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### Dataset Controls")
                save_button = gr.Button("Save conversations now", variant="secondary")
                status_output = gr.Textbox(label="Save Status", interactive=False)

                with gr.Row():
                    convo_count = gr.Number(label="Total Conversations", interactive=False)
                    next_save = gr.Textbox(label="Next Auto-Save", interactive=False)

                last_save_time_display = gr.Textbox(label="Last Save Time", interactive=False)
                dataset_name_display = gr.Textbox(label="Dataset Name", interactive=False)

                refresh_btn = gr.Button("Refresh Stats")

    # Chat wiring: both the Send button and pressing Enter in the textbox
    # stream through predict(); only the button is exposed in the API.
    submit_btn.click(
        predict,
        [msg, chatbot, temperature, top_p, conversation_id],
        [chatbot, conversation_id],
        api_name="chat"
    )

    msg.submit(
        predict,
        [msg, chatbot, temperature, top_p, conversation_id],
        [chatbot, conversation_id],
        api_name=False
    )

    save_button.click(
        save_dataset_manually,
        [],
        [status_output]
    )

    def update_stats():
        """Map get_stats() onto the four stat display components, in order."""
        stats = get_stats()
        return [
            stats["conversation_count"],
            stats["next_save"],
            stats["last_save"],
            stats["dataset_name"]
        ]

    refresh_btn.click(
        update_stats,
        [],
        [convo_count, next_save, last_save_time_display, dataset_name_display]
    )

    # BUG FIX: gr.Timeout does not exist in Gradio, so the original
    # gr.on([demo.load, gr.Timeout(30)], ...) raised AttributeError at
    # import time. Use gr.Timer (Gradio >= 4.38) for the 30s periodic
    # refresh, and a single demo.load for the initial population (the
    # original also registered a redundant second demo.load handler).
    stats_timer = gr.Timer(30)
    stats_timer.tick(
        update_stats,
        [],
        [convo_count, next_save, last_save_time_display, dataset_name_display]
    )

    demo.load(
        update_stats,
        [],
        [convo_count, next_save, last_save_time_display, dataset_name_display]
    )


# Flush any unsaved conversations when the process exits.
import atexit
atexit.register(save_to_dataset)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # share=True publishes a temporary public Gradio link for the demo.
    demo.launch(share=True)
|
|
|