Spaces:
Sleeping
Sleeping
File size: 3,966 Bytes
bc3ac26 00baf90 9a5e669 bc3ac26 935dbf9 65d226c bc3ac26 935dbf9 65d226c 935dbf9 65d226c 935dbf9 bc3ac26 9a5e669 bc3ac26 65d226c bc3ac26 935dbf9 65d226c 00baf90 65d226c 00baf90 65d226c 00baf90 bc3ac26 00baf90 bc3ac26 00baf90 bc3ac26 9a5e669 495b51a bc3ac26 9a5e669 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import psutil # For tracking CPU memory usage
import torch # For tracking GPU memory usage
# Load the shared tokenizer once and reuse it for every model below.
# NOTE(review): assumes all Flan-T5 sizes share the base checkpoint's
# vocabulary — holds for this family, but verify if new models are added.
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

# UI display name -> Hugging Face Hub model id for each selectable checkpoint.
model_names = {
    "Flan-T5-small": "google/flan-t5-small",
    "Flan-T5-base": "google/flan-t5-base",
    "Flan-T5-large": "google/flan-t5-large",
    "Flan-T5-XL": "google/flan-t5-xl"
}

# Lazily-loaded model state: only one checkpoint is kept in memory at a time.
current_model = None        # the loaded AutoModelForSeq2SeqLM, or None
current_model_name = None   # display-name key of the loaded model, or None
def load_model(model_name):
    """Return the model for *model_name*, loading it on first use.

    Keeps exactly one checkpoint resident. When the user switches models,
    the previously loaded model is released *before* the new one is loaded,
    so peak memory stays at one model instead of two (the original kept the
    old model alive until after the new one finished loading).

    Args:
        model_name: Display-name key into the module-level ``model_names``.

    Returns:
        The loaded ``AutoModelForSeq2SeqLM`` instance.
    """
    global current_model, current_model_name
    # Fast path: the requested model is already loaded.
    if model_name == current_model_name:
        return current_model
    print(f"Loading {model_name}...")
    if current_model is not None:
        # Drop the reference and reclaim memory before the new load.
        current_model = None
        import gc
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    current_model = AutoModelForSeq2SeqLM.from_pretrained(model_names[model_name])
    current_model_name = model_name
    return current_model
def get_memory_usage():
    """Format current CPU (and, when available, GPU) memory usage for display."""
    gib = 1024 ** 3
    vm = psutil.virtual_memory()
    parts = [f"CPU Memory: {vm.used / gib:.2f} GB / {vm.total / gib:.2f} GB"]
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / gib
        capacity = torch.cuda.get_device_properties(0).total_memory / gib
        parts.append(f" | GPU Memory: {allocated:.2f} GB / {capacity:.2f} GB")
    else:
        parts.append(" | GPU Memory: Not available")
    return "".join(parts)
def respond(
    message,
    history: list[tuple[str, str]],
    model_choice,
    max_tokens,
    temperature,
    top_p,
):
    """Chat handler for ``gr.ChatInterface``: yields the model's reply.

    Args:
        message: The user's latest message.
        history: Prior ``(user, assistant)`` turn pairs from the chat UI.
        model_choice: Display name of the Flan-T5 checkpoint to use.
        max_tokens: Maximum number of *new* tokens to generate.
        temperature: Sampling temperature for ``generate``.
        top_p: Nucleus-sampling cutoff for ``generate``.

    Yields:
        The decoded assistant response (single yield; the generator form is
        what ChatInterface expects from streaming-capable handlers).
    """
    # Load (or switch to) the selected model.
    model = load_model(model_choice)

    # Flatten the dialogue history into a single prompt string.
    turns = [f"User: {user_msg} Assistant: {bot_msg} " for user_msg, bot_msg in history]
    input_text = "".join(turns) + f"User: {message}"

    # Truncate to the model's maximum input length to avoid oversize inputs.
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True)

    # BUG FIX: the UI slider is labelled "Max new tokens", but the original
    # passed it as ``max_length`` (a cap on total decoder length). Use
    # ``max_new_tokens`` so the slider means what it says. Passing
    # ``**inputs`` also forwards the attention_mask alongside input_ids.
    output_tokens = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    # Decode and hand the assistant's reply back to the UI.
    response = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
    yield response
# Callback used to periodically refresh the memory-usage readout.
def update_memory_widget():
    """Return a fresh memory-usage string for the Textbox widget."""
    usage = get_memory_usage()
    return usage
# Assemble the page: a chat interface plus a live memory-usage readout.
with gr.Blocks() as interface:
    gr.Markdown("### Model Selection and Memory Usage")
    # Main chat UI. The additional inputs appear in ChatInterface's
    # "Additional Inputs" section and are forwarded to respond() in order:
    # model_choice, max_tokens, temperature, top_p.
    demo = gr.ChatInterface(
        respond,
        additional_inputs=[
            gr.Dropdown(
                choices=["Flan-T5-small", "Flan-T5-base", "Flan-T5-large", "Flan-T5-XL"],
                value="Flan-T5-base",  # default checkpoint
                label="Model"
            ),
            gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
        ],
    )
    # NOTE(review): components created inside a Blocks context normally
    # render automatically; this explicit render() may be redundant (or, on
    # some Gradio versions, render a second copy) — confirm against the
    # installed Gradio version.
    demo.render()
    # Read-only readout seeded with the usage measured at startup.
    memory_widget = gr.Textbox(label="Memory Usage", interactive=False, value=get_memory_usage())
    # NOTE(review): gr.Row is a layout context manager; passing a component
    # list positionally is not the documented API, and the Textbox above is
    # already rendered — this line likely has no effect. Verify/remove.
    gr.Row([memory_widget])
    # Intent: refresh the memory readout every second after page load.
    # NOTE(review): `stream_every` governs streaming frequency of generator
    # outputs; for a plain periodic re-run the expected parameter is
    # `every=1` (or a gr.Timer on Gradio 5) — confirm against the installed
    # Gradio version.
    interface.load(update_memory_widget, None, memory_widget, stream_every=1)

# Standard script entry point: launch the app only when run directly.
if __name__ == "__main__":
    interface.launch()
|