parjanya20 committed
Commit 40ba062 · verified · 1 parent: 700396f

Update app.py

Files changed (1): app.py (+39 −1)
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from verbatim_llm import TokenSwapProcessor
+import gc
 
 # Predefined model pairs
 MODEL_PAIRS = {
@@ -14,6 +15,32 @@ MODEL_PAIRS = {
 loaded_models = {}
 current_pair = None
 
+def clear_models():
+    global loaded_models, current_pair
+
+    try:
+        # Clear models from memory
+        if loaded_models:
+            # Move models to CPU if they were on GPU
+            for key, value in loaded_models.items():
+                if hasattr(value, 'to'):
+                    value.to('cpu')
+                del value
+
+        loaded_models = {}
+        current_pair = None
+
+        # Force garbage collection
+        gc.collect()
+
+        # Clear GPU cache if available
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        return "✅ Models cleared from memory"
+    except Exception as e:
+        return f"❌ Error clearing models: {str(e)}"
+
 def load_models(model_pair):
     global loaded_models, current_pair
 
@@ -21,6 +48,10 @@ def load_models(model_pair):
         return "Models already loaded!"
 
     try:
+        # Clear existing models first if switching
+        if loaded_models:
+            clear_models()
+
         main_model_name, aux_model_name = MODEL_PAIRS[model_pair]
         device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -87,7 +118,9 @@ with gr.Blocks(title="Verbatim-LLM Demo") as app:
             value=list(MODEL_PAIRS.keys())[0],
             label="Model Pair"
         )
-        load_btn = gr.Button("Load Models", variant="primary")
+        with gr.Column():
+            load_btn = gr.Button("Load Models", variant="primary")
+            clear_btn = gr.Button("Clear Models", variant="secondary")
 
         status = gr.Textbox(label="Status", interactive=False)
 
@@ -114,6 +147,11 @@ with gr.Blocks(title="Verbatim-LLM Demo") as app:
         outputs=[status]
     )
 
+    clear_btn.click(
+        fn=clear_models,
+        outputs=[status]
+    )
+
    generate_btn.click(
        fn=lambda p, t: (generate_text(p, t, False), generate_text(p, t, True)),
        inputs=[prompt_box, max_tokens],
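
The unload sequence that `clear_models()` introduces can be exercised on its own. Below is a minimal, self-contained sketch of the same pattern; the `models` dict and the `gpt2` checkpoint are stand-ins for illustration, not part of this app:

import gc
import torch
from transformers import AutoModelForCausalLM

# Stand-in for the app's `loaded_models` registry.
models = {"main": AutoModelForCausalLM.from_pretrained("gpt2")}

# Move weights off the GPU before dropping references, so no live
# tensor keeps CUDA memory pinned.
for model in models.values():
    if hasattr(model, "to"):
        model.to("cpu")

models.clear()  # drop the Python references
gc.collect()    # collect the now-unreachable modules
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # hand cached blocks back to the driver

The ordering matters: `torch.cuda.empty_cache()` only releases cached blocks that no live tensor still references, so dropping the references and running `gc.collect()` first is what makes the final call effective.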