Your Name committed on
Commit
1d9f921
·
1 Parent(s): 03c29a9

Add @spaces.GPU decorator for ZeroGPU support

Browse files
Files changed (1) hide show
  1. app.py +22 -21
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import gradio as gr
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
 
@@ -18,40 +19,40 @@ MODELS = {
18
  }
19
  }
20
 
21
- # Global cache for loaded models
22
- loaded_models = {}
23
-
24
- def load_model(model_name):
25
- """Load model and tokenizer, cache if already loaded"""
26
- if model_name not in loaded_models:
27
- tokenizer = AutoTokenizer.from_pretrained(model_name)
28
- model = AutoModelForCausalLM.from_pretrained(
29
- model_name,
30
- torch_dtype=torch.float16,
31
- device_map="auto"
32
- )
33
- loaded_models[model_name] = (model, tokenizer)
34
-
35
- return loaded_models[model_name]
36
 
 
37
  def generate_text(model_name, prompt, max_tokens, temperature, top_p):
38
- """Generate text using selected model"""
 
 
39
  try:
40
- model, tokenizer = load_model(model_name)
 
 
 
 
 
 
 
 
41
 
42
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
43
 
44
  with torch.no_grad():
45
- outputs = model.generate(
46
  **inputs,
47
  max_new_tokens=max_tokens,
48
  temperature=temperature,
49
  top_p=top_p,
50
  do_sample=True,
51
- pad_token_id=tokenizer.eos_token_id
52
  )
53
 
54
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
55
  return response
56
 
57
  except Exception as e:
 
1
  import gradio as gr
2
+ import spaces
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch
5
 
 
19
  }
20
  }
21
 
22
+ # Global storage for models (will be loaded on GPU)
23
+ current_model = None
24
+ current_tokenizer = None
25
+ current_model_name = None
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ @spaces.GPU
28
  def generate_text(model_name, prompt, max_tokens, temperature, top_p):
29
+ """Generate text using selected model with ZeroGPU"""
30
+ global current_model, current_tokenizer, current_model_name
31
+
32
  try:
33
+ # Load model if not loaded or different model selected
34
+ if current_model is None or current_model_name != model_name:
35
+ current_tokenizer = AutoTokenizer.from_pretrained(model_name)
36
+ current_model = AutoModelForCausalLM.from_pretrained(
37
+ model_name,
38
+ torch_dtype=torch.float16,
39
+ device_map="auto"
40
+ )
41
+ current_model_name = model_name
42
 
43
+ inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
44
 
45
  with torch.no_grad():
46
+ outputs = current_model.generate(
47
  **inputs,
48
  max_new_tokens=max_tokens,
49
  temperature=temperature,
50
  top_p=top_p,
51
  do_sample=True,
52
+ pad_token_id=current_tokenizer.eos_token_id
53
  )
54
 
55
+ response = current_tokenizer.decode(outputs[0], skip_special_tokens=True)
56
  return response
57
 
58
  except Exception as e: