# ds_553_cs1 / app.py
# (Hugging Face Space by sdkrastev — commit 9a0a441, "cleaning code comments")
import gradio as gr
from huggingface_hub import InferenceClient
from transformers import pipeline
# Remote client for distilgpt2 via the Hugging Face Inference API.
client = InferenceClient("distilgpt2")
# Default mode is remote so that model doesn't load yet
run_local = False
pipe = None # Only initialize when switching to local
def respond(message):
    """Generate a completion for *message* using the currently selected mode.

    Remote mode delegates generation to the hosted Inference API; local
    mode lazily constructs a transformers pipeline on first use (slow the
    first time) and reuses it afterwards.
    """
    global pipe
    if not run_local:
        # Remote path: should be faster, no model download required.
        return client.text_generation(message, max_new_tokens=50)
    # Local path: build the pipeline once, then keep it cached in `pipe`.
    if pipe is None:
        pipe = pipeline("text-generation", model="distilgpt2", device="cpu")
    outputs = pipe(message, max_new_tokens=50, do_sample=True)
    return outputs[0]['generated_text']
def set_mode(selected_mode):
    """Switch between remote and local inference.

    Records the choice in the module-level ``run_local`` flag and discards
    any previously loaded local pipeline, then returns a status string.
    """
    global run_local, pipe
    run_local = selected_mode == "Local"
    # Drop the cached model so switching back to Local reloads it fresh.
    pipe = None
    mode_name = "Local" if run_local else "Remote"
    return f"Switched to {mode_name} Mode"
def update_chat(history, message):
    """Append the (user, bot) exchange to *history* and clear the input.

    Returns the mutated history plus an empty string, which Gradio uses
    to blank the message textbox.
    """
    reply = respond(message)
    history.append((message, reply))
    return history, ""
# --- Gradio UI: mode selector + simple chat interface ---------------------
with gr.Blocks() as demo:
    # Radio toggle between remote API and local pipeline inference.
    mode = gr.Radio(choices=["Remote", "Local"], value="Remote", label="Select Mode", interactive=True)
    # set_mode's status string is discarded (outputs=None); only the
    # module-level flag side effect matters.
    mode.change(set_mode, mode, None)
    chatbot = gr.Chatbot()
    message_input = gr.Textbox(label="Enter your message", placeholder="Type your message here...")
    send_button = gr.Button("Send")
    # update_chat appends the exchange to the chat and clears the textbox.
    send_button.click(update_chat, [chatbot, message_input], [chatbot, message_input])
demo.launch()