Spaces:
Sleeping
Sleeping
File size: 8,017 Bytes
44cf422 5ddd366 44cf422 64224f9 5ddd366 44cf422 5ddd366 44cf422 5ddd366 44cf422 5ddd366 44cf422 5ddd366 44cf422 5ddd366 44cf422 5ddd366 44cf422 5ddd366 4e91bcc 5ddd366 4e91bcc 5ddd366 dfd9652 c02cea7 5ddd366 44cf422 5ddd366 44cf422 5ddd366 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 |
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Load the fine-tuned GPT-2 chat model and its tokenizer from the
# Hugging Face Hub. `tokenizer` and `model` are module-level globals
# used by generate_response() below.
model_path = "alexdev404/gpt2-finetuned-chat"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)
# GPT-2 ships without a pad token; reuse EOS so that tokenization with
# padding=True (and generate's pad_token_id) works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
def generate_response(prompt, system_message, conversation_history=None, max_tokens=75, temperature=0.78, top_p=0.85, repetition_penalty=1.031, top_k=55):
    """Generate a reply with the fine-tuned GPT-2 chat model.

    Builds a prompt in the custom training format
    (``<|start|>Role:<|message|>text<|end|>``), samples a continuation,
    and returns only the newly generated text.

    Args:
        prompt: The latest user message.
        system_message: Optional persona text; when non-empty it is injected
            ahead of the first history message as a synthetic User/Assistant
            exchange (matching the training data).
        conversation_history: Optional list of ``{"role", "content"}`` dicts
            in Gradio "messages" format.
        max_tokens: Maximum number of new tokens to sample.
        temperature: Softmax temperature for sampling.
        top_p: Nucleus-sampling cutoff.
        repetition_penalty: Penalty applied to already-generated tokens.
        top_k: Accepted for interface compatibility; intentionally NOT
            forwarded to ``model.generate`` (the call site keeps it disabled).

    Returns:
        The raw decoded continuation. It may still contain ``<|end|>`` and
        other format markers; callers (see ``respond``) clean those up.
    """
    # Build the conversation context in the training format.
    context = ""
    if conversation_history:
        # Use a generous slice of history to fill GPT-2's 1024-token context
        # window: at ~20-30 tokens per exchange, the last 30 messages fit.
        recent = conversation_history[-30:]
        for i, message in enumerate(recent):
            if i == 0 and system_message and system_message.strip():
                # Prepend the system message as a synthetic first exchange.
                context += f"<|start|>User:<|message|>{system_message}<|end|>\n<|start|>Assistant:<|message|>Hey, what's up nice to meet you. I'm glad to be here!<|end|>\n"
            if message['role'] == 'user':
                context += f"<|start|>User:<|message|>{message['content']}<|end|>\n"
            else:
                context += f"<|start|>Assistant:<|message|>{message['content']}<|end|>\n"

    # Format the new turn to match training, leaving the assistant tag open
    # so the model continues from "<|message|>".
    formatted_input = f"{context}<|start|>User:<|message|>{prompt}<|end|>\n<|start|>Assistant:<|message|>"
    # Debug: show exactly what is fed to the model.
    print(f"Formatted input: {repr(formatted_input)}")

    # Strings/phrases the model must never emit: known GPT-2 "glitch"
    # tokens plus artifacts observed in this model's output.
    bad_words = [
        # External system tokens
        "externalToEVAOnly", " externalToEVAOnly", " externalToEVA",
        # Magic/system tokens
        " SolidGoldMagikarp", "GoldMagikarp", "PsyNetMessage",
        # Pattern tokens
        "?????-?????-", "???????-", "?????-", "off-", ",...", ":aution", " ?",
        " !", " .", " ,", " ??", " !!", " ...", " ?!", " .!", " ,!", " !.", " ,.", " ..", " !?",
        # Embed/report tokens
        "embedreportprint", "cloneembedreportprint", "reportprint",
        "rawdownload", "rawdownloadcloneembedreportprint",
        # GUI tokens
        " guiActiveUn", " guiActiveUnfocused", " guiIcon",
        # Stream/bot tokens
        "EStreamFrame", "StreamerBot", "TPPStreamerBot",
        # Reddit/user tokens
        "RandomRedditorWithNo", "RandomRedditor",
        # Store/commerce tokens
        "InstoreAndOnline", "oreAndOnline", "BuyableInstoreAndOnline", "quickShip",
        # Download/magazine tokens
        "Downloadha", "DragonMagazine",
        # Technical tokens
        " attRot", " srfN", " DevOnline",
        # Nitrome tokens
        " TheNitrome", " TheNitromeFan",
        # User tokens
        " davidjl",
        # Japanese strings
        " 裏覚醒", " サーティ", " サーティワン", "ゼウス",
        # Random crap
        "Citizendium", "Orderable", "Buyable", "Citizi", "SpaceEngineers", "senal", "oaded", "eatures",
        "compe", "autioning", "compe..."
    ]
    # BUG FIX: the previous code ran each phrase through
    # convert_tokens_to_ids(), which only resolves exact single vocabulary
    # tokens. Multi-token phrases came back as the unk id — and GPT-2's unk
    # equals its eos (<|endoftext|>) — so the intended phrases were never
    # banned and the EOS token risked being banned instead. Encoding each
    # phrase yields the correct (possibly multi-token) id sequence, which is
    # the documented input shape for generate(bad_words_ids=...).
    bad_words_ids = [
        ids
        for ids in (tokenizer(w, add_special_tokens=False).input_ids for w in bad_words)
        if ids  # drop phrases that tokenize to nothing
    ]
    # Known garbage token ids, banned directly.
    bad_words_ids.extend([[token_id] for token_id in [30213, 35793, 7589, 1394, 1539, 126, 33434, 11689, 13945, 116, 147, 251, 143, 12662]])

    inputs = tokenizer(
        formatted_input,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    )
    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            # top_k=top_k, # Consider top 55 tokens
            do_sample=True,
            no_repeat_ngram_size=5,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.pad_token_id,
            # Stop when the model emits the end-of-turn marker.
            eos_token_id=tokenizer.convert_tokens_to_ids("<|end|>"),
            bad_words_ids=bad_words_ids,
        )
    # Decode only the newly generated tokens (skip the prompt echo).
    new_tokens = outputs[0][inputs.input_ids.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=False)
    return response.strip()
def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    repetition_penalty,
    top_k,
):
    """Gradio ChatInterface callback backed by the local GPT-2 model.

    Forwards the user message, the running history, and the sampling
    settings to ``generate_response`` and returns the cleaned reply.
    """
    # With type="messages", Gradio's history already has the shape
    # [{"role": "user"|"assistant", "content": "..."}] that
    # generate_response expects, so it is passed through unchanged.
    print(f"User message: {message}")
    print(f"History length: {len(history)}")

    reply = generate_response(
        message,
        system_message,
        history,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        top_k=top_k,
    )

    # Keep only the text before the first end-of-turn marker (no-op when
    # the marker is absent).
    reply = reply.partition("<|end|>")[0]
    return reply.strip()
"""
Gradio ChatInterface for your custom GPT-2 model
"""
chatbot = gr.ChatInterface(
respond,
type="messages",
title="Chat with the model",
description="Chat with the GPT-2-based model trained on WhatsApp data",
additional_inputs=[
# gr.Textbox(value="Hey I\'m Alice and you\'re Grace. You are having a casual peer-to-peer conversation with someone. Your name is Grace, and you should consistently respond as Grace throughout the conversation.\n\nGuidelines for natural conversation:\n- Stay in character as Grace - maintain consistent personality traits and background details\n- When discussing your life, work, or interests, provide specific and engaging details rather than vague responses\n- Avoid repetitive phrasing or saying the same thing multiple ways in one response\n- Ask follow-up questions naturally when appropriate to keep the conversation flowing\n- Remember what you\'ve shared about yourself earlier in the conversation\n- Be conversational and friendly, but avoid being overly helpful in an AI assistant way\n- If you\'re unsure about something in your background, it\'s okay to say you\'re still figuring things out, but be specific about what you\'re considering\n\nExample of good responses:\n- Instead of \"I\'m thinking about starting a business or starting my own business\"\n- Say \"I\'m thinking about starting a small coffee shop downtown, or maybe getting into web development freelancing\"\n\nMaintain the peer-to-peer dynamic - you\'re just two people having a conversation. The user has entered the chat. Introduce yourself.", label="System message"),
gr.Textbox(value="", label="System message (NOT SUPPORTED - leave blank)"),
gr.Slider(minimum=10, maximum=150, value=15, step=5, label="Max new tokens"),
gr.Slider(minimum=0.01, maximum=1.2, value=0.2, step=0.01, label="Temperature"),
gr.Slider(
minimum=0.01,
maximum=1.0,
value=0.9,
step=0.01,
label="Top-p (nucleus sampling)",
),
gr.Slider(
minimum=1.0,
maximum=100,
value=6.01,
step=0.001,
label="Repetition penalty",
),
gr.Slider(
minimum=1,
maximum=100,
value=55,
step=1,
label="Top-k (prediction sampling)",
),
],
)
# Wrap the chat interface in a themed Blocks app.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    chatbot.render()

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",  # Makes it accessible from other devices on your network
        server_port=7860,  # Default gradio port
        share=False,  # Set to True to get a public shareable link
        debug=True
    )
|