reydeuss commited on
Commit
0d3a311
·
verified ·
1 Parent(s): 3e3dd41

update app.py: remove the invalid @gradio_cached_function decorator and replace it with an explicit lazy-loading load_model_once() function that caches the model and tokenizer in module-level globals

Browse files
Files changed (1) hide show
  1. app.py +27 -12
app.py CHANGED
@@ -1,23 +1,38 @@
1
- # app.py for Gradio
2
  import gradio as gr
3
  import torch
4
  from transformers import T5Tokenizer, T5ForConditionalGeneration
 
5
 
6
- @gradio_cached_function
7
- def load_model():
8
- model = T5ForConditionalGeneration.from_pretrained(
9
- "cahya/t5-base-indonesian-summarization",
10
- load_in_8bit=True,
11
- device_map="auto"
12
- )
13
- tokenizer = T5Tokenizer.from_pretrained("cahya/t5-base-indonesian-summarization")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  return model, tokenizer
15
 
16
  def summarize_text(text):
17
  if not text.strip():
18
  return "Please enter text to summarize."
19
 
20
- model, tokenizer = load_model()
21
 
22
  # Add T5 prefix
23
  input_text = f"summarize: {text}"
@@ -42,10 +57,10 @@ interface = gr.Interface(
42
  inputs=gr.Textbox(lines=10, placeholder="Enter Indonesian text here...", label="Input Text"),
43
  outputs=gr.Textbox(label="Generated Summary"),
44
  title="Indonesian Text Summarization",
45
- description="Enter Indonesian text to generate a summary using T5 model",
46
  examples=[
47
  ["Your example Indonesian text here..."]
48
  ]
49
  )
50
 
51
- interface.launch()
 
1
+ # app.py for Gradio with PEFT
2
  import gradio as gr
3
  import torch
4
  from transformers import T5Tokenizer, T5ForConditionalGeneration
5
+ from peft import PeftModel, PeftConfig
6
 
7
# Model and tokenizer are cached at module level; both start unset and are
# populated lazily on the first call to load_model_once().
model = None
tokenizer = None

def load_model_once():
    """Lazily load the 8-bit base T5 model with PEFT adapters, once.

    On the first call this loads the tokenizer, the quantized base model,
    and the trained PEFT adapters, then caches them in the module-level
    ``model`` / ``tokenizer`` globals. Subsequent calls return the cached
    pair without reloading.

    Returns:
        tuple: ``(model, tokenizer)`` ready for inference.
    """
    global model, tokenizer
    # Guard on BOTH globals so a partially completed earlier attempt is
    # retried instead of being served half-initialized.
    if model is None or tokenizer is None:
        base_model_name = "cahya/t5-base-indonesian-summarization"
        # Build everything in locals first and publish to the globals only
        # after all loads succeed: if any step below raises, the globals
        # stay None and the next call retries cleanly.
        loaded_tokenizer = T5Tokenizer.from_pretrained(base_model_name)
        base_model = T5ForConditionalGeneration.from_pretrained(
            base_model_name,
            # NOTE(review): load_in_8bit quantizes via bitsandbytes and
            # requires a CUDA device — it is NOT a CPU optimization.
            # Confirm the Space has a GPU, or drop this flag for CPU.
            load_in_8bit=True,
            device_map="auto",
        )
        # Attach the trained PEFT adapters on top of the frozen base model.
        # Replace the path with the actual adapter directory (e.g. files
        # exported from the wandb artifact).
        loaded_model = PeftModel.from_pretrained(
            base_model,
            "./path-to-your-adapter",
        )
        model, tokenizer = loaded_model, loaded_tokenizer
    return model, tokenizer
30
 
31
  def summarize_text(text):
32
  if not text.strip():
33
  return "Please enter text to summarize."
34
 
35
+ model, tokenizer = load_model_once()
36
 
37
  # Add T5 prefix
38
  input_text = f"summarize: {text}"
 
57
  inputs=gr.Textbox(lines=10, placeholder="Enter Indonesian text here...", label="Input Text"),
58
  outputs=gr.Textbox(label="Generated Summary"),
59
  title="Indonesian Text Summarization",
60
+ description="Enter Indonesian text to generate a summary using T5 model with PEFT adapters",
61
  examples=[
62
  ["Your example Indonesian text here..."]
63
  ]
64
  )
65
 
66
+ interface.launch()