Spaces:
Paused
Paused
| import os | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
| import gradio as gr | |
# Hugging Face Hub model repos this app knows about; the first entry is the default.
MODEL_LIST = ["nawhgnuj/DonaldTrump-Llama-3.1-8B-Chat"]

# Hub access token from the environment (None when the variable is unset).
HF_TOKEN = os.getenv("HF_TOKEN")

# Model repo to load; the MODEL_ID environment variable overrides the default.
MODEL = os.getenv("MODEL_ID", MODEL_LIST[0])

# Page heading rendered via gr.HTML.
TITLE = "<h1 style='color: #B71C1C; text-align: center;'>Donald Trump Chatbot</h1>"

# Image URL used as the bot's avatar in the chat window.
TRUMP_AVATAR = "https://upload.wikimedia.org/wikipedia/commons/5/56/Donald_Trump_official_portrait.jpg"

# Custom stylesheet passed to gr.Blocks(css=...).
CSS = """
.chatbot {
background-color: white;
}
.duplicate-button {
margin: auto !important;
color: white !important;
background: #B71C1C !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
color: #B71C1C;
}
.contain {object-fit: contain}
.avatar {width: 80px; height: 80px; border-radius: 50%; object-fit: cover;}
.user-message {
background-color: white !important;
color: black !important;
}
.bot-message {
background-color: #B71C1C !important;
color: white !important;
}
"""
# Preferred compute device.
# NOTE(review): `device` is not referenced elsewhere in the visible code;
# placement is actually handled by device_map="auto" below.
device = "cuda" if torch.cuda.is_available() else "cpu"

# 4-bit NF4 quantization (bitsandbytes) with double quantization and bf16
# compute, to reduce the model's memory footprint.
# NOTE(review): bitsandbytes 4-bit loading generally requires a CUDA GPU; on a
# CPU-only host the load below is likely to fail -- confirm for the target hardware.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# Pass HF_TOKEN explicitly (None when unset) so gated or private model repos
# can be downloaded with the configured credentials.
tokenizer = AutoTokenizer.from_pretrained(MODEL, token=HF_TOKEN)

# Some causal-LM tokenizers ship without a pad token; fall back to EOS so
# generate() has a pad id to use.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config,
    token=HF_TOKEN,
)
def _build_conversation(message: str, history: list) -> list:
    """Build the chat-template message list for one generation request.

    Args:
        message: The new user message.
        history: Prior turns as ``(user, assistant)`` pairs.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts. It holds the system
        prompt, then every completed prior exchange, then ``message`` as the
        final user turn.
    """
    system_prompt = """You are a Donald Trump chatbot. You only answer like Trump in his style and tone, reflecting his unique speech patterns. Incorporate the following characteristics in every response:
1. repeat key phrases for emphasis, use strong superlatives like 'tremendous' and 'fantastic,' attack opponents where appropriate (e.g., 'fake news media,' 'radical left')
2. focus on personal successes ('nobody's done more than I have')
3. keep sentences short and impactful, and show national pride.
4. Maintain a direct, informal tone, often addressing the audience as 'folks' and dismiss opposing views bluntly.
5. Repeat key phrases for emphasis, but avoid excessive repetition.
Importantly, always respond to points in Trump's style. Keep responses concise and avoid unnecessary repetition.
"""
    conversation = [{"role": "system", "content": system_prompt}]
    for prompt, answer in history:
        # A turn whose answer is still None never completed (e.g. an earlier
        # generation raised). Drop the whole exchange so no None content reaches
        # the chat template and user/assistant roles keep alternating.
        if answer is None:
            continue
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})
    return conversation


def generate_response(
    message: str,
    history: list,
    temperature: float,
    max_new_tokens: int,
    top_p: float,
    top_k: int,
):
    """Sample a reply to ``message`` from the chat model.

    Args:
        message: The new user message.
        history: Prior turns as ``(user, assistant)`` pairs. Pairs whose
            assistant reply is ``None`` are skipped.
        temperature: Sampling temperature.
        max_new_tokens: Upper bound on generated tokens.
        top_p: Nucleus-sampling mass. Values above 1.0 are clamped to 1.0.
        top_k: Number of highest-probability tokens considered per step.

    Returns:
        The generated text after the prompt. Special tokens are removed and
        surrounding whitespace is stripped.
    """
    conversation = _build_conversation(message, history)
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    # The pad token may be the EOS token (the tokenizer setup falls back to
    # it), in which case generate() cannot infer a mask from pad positions.
    # The prompt is a single unpadded sequence, so every position is attended.
    attention_mask = torch.ones_like(input_ids)

    # Top-p is a cumulative probability, so anything above 1.0 is meaningless
    # (and may be rejected by some transformers versions). Integer casts guard
    # against float values arriving from UI sliders.
    top_p = min(float(top_p), 1.0)
    top_k = int(top_k)
    max_new_tokens = int(max_new_tokens)

    with torch.no_grad():
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    response = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
    return response.strip()
def add_text(history, text):
    """Append the user's message to the chat history and clear the textbox.

    Args:
        history: Current chatbot value, a list of ``[user, bot]`` pairs. It may
            be empty or ``None``.
        text: Raw textbox contents.

    Returns:
        ``(new_history, "")``. A blank or whitespace-only ``text`` leaves the
        history unchanged, so an empty prompt is never sent to the model. The
        input list is not mutated.
    """
    history = list(history or [])
    if text and text.strip():
        # A list (not a tuple) so ``bot`` can fill in the reply slot.
        history.append([text, None])
    return history, ""


def bot(history, temperature, max_new_tokens, top_p, top_k):
    """Generate the reply for the latest, still-unanswered user turn.

    Args:
        history: List of ``[user, bot]`` pairs; the last pair is expected to
            have ``None`` as its reply.
        temperature, max_new_tokens, top_p, top_k: Sampling settings forwarded
            to ``generate_response``.

    Returns:
        ``history`` with the last entry's reply filled in. The history is
        returned unchanged when it is empty or its last turn already has a
        reply (e.g. ``add_text`` ignored a blank submission), so an answered
        turn is never regenerated.
    """
    if not history or history[-1][1] is not None:
        return history
    user_message = history[-1][0]
    bot_response = generate_response(
        user_message, history[:-1], temperature, max_new_tokens, top_p, top_k
    )
    # Replace the whole entry rather than assigning into it: it may be a tuple.
    history[-1] = [user_message, bot_response]
    return history
# Gradio UI: chat window, input box, action buttons, sampling controls and
# example prompts, wired to add_text -> bot.
with gr.Blocks(css=CSS, theme=gr.themes.Default()) as demo:
    gr.HTML(TITLE)
    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        avatar_images=(None, TRUMP_AVATAR),
        height=600,
        bubble_full_width=False,
        show_label=False,
    )
    msg = gr.Textbox(
        placeholder="Ask Donald Trump a question",
        container=False,
        scale=7,
    )
    with gr.Row():
        submit = gr.Button("Submit", scale=1, variant="primary")
        clear = gr.Button("Clear", scale=1)
    with gr.Accordion("Advanced Settings", open=False):
        temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.8, step=0.1, label="Temperature")
        max_new_tokens = gr.Slider(minimum=50, maximum=1024, value=1024, step=1, label="Max New Tokens")
        # Top-p is a cumulative probability mass, so the useful range ends at 1.0.
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.1, label="Top-p")
        top_k = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Top-k")
    gr.Examples(
        examples=[
            ["What's your stance on immigration?"],
            ["How would you describe your economic policies?"],
            ["What are your thoughts on the media?"],
        ],
        inputs=msg,
    )
    # The Submit button and pressing Enter in the textbox run the same chain.
    # add_text records the user turn immediately (unqueued), then bot fills in
    # the model's reply.
    submit.click(add_text, [chatbot, msg], [chatbot, msg], queue=False).then(
        bot, [chatbot, temperature, max_new_tokens, top_p, top_k], chatbot
    )
    clear.click(lambda: [], outputs=[chatbot], queue=False)
    msg.submit(add_text, [chatbot, msg], [chatbot, msg], queue=False).then(
        bot, [chatbot, temperature, max_new_tokens, top_p, top_k], chatbot
    )
def _launch() -> None:
    """Start the Gradio server hosting the ``demo`` app."""
    demo.launch()


if __name__ == "__main__":
    # Serve only when run as a script, not when imported.
    _launch()