File size: 3,966 Bytes
bc3ac26
00baf90
9a5e669
 
bc3ac26
935dbf9
65d226c
bc3ac26
935dbf9
65d226c
 
935dbf9
 
 
65d226c
 
935dbf9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc3ac26
9a5e669
 
 
 
 
 
 
 
 
 
 
 
 
 
bc3ac26
 
 
65d226c
bc3ac26
 
 
 
935dbf9
 
65d226c
00baf90
 
 
 
 
 
65d226c
00baf90
 
65d226c
00baf90
 
 
bc3ac26
 
00baf90
 
bc3ac26
00baf90
 
 
bc3ac26
9a5e669
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
495b51a
bc3ac26
 
9a5e669
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gc  # For releasing a model before loading another

import gradio as gr
import psutil  # For tracking CPU memory usage
import torch  # For tracking GPU memory usage
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Selectable Flan-T5 variants: UI label -> Hugging Face Hub repository id.
model_names = {
    "Flan-T5-small": "google/flan-t5-small",
    "Flan-T5-base": "google/flan-t5-base",
    "Flan-T5-large": "google/flan-t5-large",
    "Flan-T5-XL": "google/flan-t5-xl",
}

# A single tokenizer, loaded from the base checkpoint, is used with every
# variant above (respond() tokenizes and decodes with it for all models).
tokenizer = AutoTokenizer.from_pretrained(model_names["Flan-T5-base"])

# Single-slot cache for whichever model is currently loaded; filled lazily
# by load_model().
current_model = None
current_model_name = None

def load_model(model_name):
    """Return the model for ``model_name``, loading it on first use or on a switch.

    Only one model is kept at a time, cached in the module-level
    ``current_model`` / ``current_model_name``. When switching, the previous
    model is released *before* the new one is loaded so the two are never
    held simultaneously, which keeps peak memory down for the larger variants.

    Args:
        model_name: A key of ``model_names`` (e.g. ``"Flan-T5-base"``).

    Returns:
        The loaded ``AutoModelForSeq2SeqLM`` instance.

    Raises:
        KeyError: If ``model_name`` is not a key of ``model_names``.
    """
    global current_model, current_model_name

    # Fast path: the requested model is already loaded.
    if model_name == current_model_name:
        return current_model

    # Resolve the repo id first so an unknown name fails without evicting
    # the currently cached model.
    repo_id = model_names[model_name]

    # Release the old model (and reset the cache key with it, so a failed load
    # leaves a consistent empty cache that the next call will retry).
    current_model = None
    current_model_name = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print(f"Loading {model_name}...")
    current_model = AutoModelForSeq2SeqLM.from_pretrained(repo_id)
    current_model_name = model_name
    return current_model

def get_memory_usage():
    """Return current CPU and GPU memory usage as a single formatted string.

    The CPU part reports ``psutil.virtual_memory()`` used/total. The GPU part
    reports ``torch.cuda.memory_allocated`` against the total capacity of the
    same (current) CUDA device, or "Not available" when CUDA is absent.
    Sizes are divided by 1024**3 and labelled "GB".

    Returns:
        e.g. ``"CPU Memory: 3.10 GB / 15.50 GB | GPU Memory: Not available"``.
    """
    gib = 1024**3
    memory_info = psutil.virtual_memory()
    cpu_memory = f"CPU Memory: {memory_info.used / gib:.2f} GB / {memory_info.total / gib:.2f} GB"

    if torch.cuda.is_available():
        # Query allocation and capacity on the same device so the two numbers
        # are comparable on multi-GPU machines.
        device = torch.cuda.current_device()
        gpu_memory = torch.cuda.memory_allocated(device) / gib
        gpu_total = torch.cuda.get_device_properties(device).total_memory / gib
        gpu_memory_info = f" | GPU Memory: {gpu_memory:.2f} GB / {gpu_total:.2f} GB"
    else:
        gpu_memory_info = " | GPU Memory: Not available"

    return cpu_memory + gpu_memory_info

def respond(
    message,
    history: list[tuple[str, str]],
    model_choice,
    max_tokens,
    temperature,
    top_p,
):
    """Generate the assistant's reply for ``gr.ChatInterface``.

    Args:
        message: The newest user message.
        history: Earlier ``(user, assistant)`` turns supplied by Gradio.
        model_choice: Key of ``model_names`` selecting the Flan-T5 variant.
        max_tokens: Maximum number of *new* tokens to generate (the UI slider
            is labelled "Max new tokens").
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The decoded reply, exactly once.
    """
    # Load the selected model (or switch models if needed)
    model = load_model(model_choice)

    # Flatten the conversation into one plain-text prompt. A missing assistant
    # reply becomes empty text instead of the literal string "None".
    turns = [f"User: {user_msg} Assistant: {bot_msg or ''}" for user_msg, bot_msg in history]
    turns.append(f"User: {message}")
    input_text = " ".join(turns)

    # Tokenize without truncation, then keep only the most recent tokens.
    # truncation=True cuts from the right by default, which would drop the
    # newest user message first once the conversation exceeds the limit.
    encoded = tokenizer(input_text, return_tensors="pt")
    max_len = tokenizer.model_max_length
    model_inputs = {}
    for key, tensor in encoded.items():
        if tensor.shape[-1] > max_len:
            tensor = tensor[:, -max_len:]
        model_inputs[key] = tensor.to(model.device)

    # Pass the attention mask along with the ids, and bound the number of
    # generated tokens (not the total length), as the UI label promises.
    output_tokens = model.generate(
        **model_inputs,
        max_new_tokens=int(max_tokens),
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    # Decode and return the assistant's response
    response = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
    yield response

# Callback that refreshes the memory-usage widget in the Gradio interface
def update_memory_widget():
    """Refresh callback for the memory textbox.

    Returns:
        The latest summary string produced by ``get_memory_usage``.
    """
    snapshot = get_memory_usage()
    return snapshot

# Build the chat UI outside the page layout and render it once inside it (the
# documented way to embed a ChatInterface); constructing it inside the Blocks
# context and then calling render() would add its components twice.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Dropdown(
            choices=list(model_names),  # exactly the keys load_model() accepts
            value="Flan-T5-base",  # Default selection
            label="Model",
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

with gr.Blocks() as interface:
    gr.Markdown("### Model Selection and Memory Usage")

    # Render the main chat interface
    demo.render()

    # Memory usage widget. Children of a Row must be created inside its
    # context; passing the callable (not its result) makes each page load
    # start from a fresh reading instead of the import-time snapshot.
    with gr.Row():
        memory_widget = gr.Textbox(label="Memory Usage", interactive=False, value=get_memory_usage)

    # Refresh the readout every second. (stream_every on load() only affects
    # streaming events and never triggered periodic updates.)
    memory_timer = gr.Timer(1)
    memory_timer.tick(update_memory_widget, None, memory_widget)

if __name__ == "__main__":
    interface.launch()