flamiry commited on
Commit
d502ee2
·
verified ·
1 Parent(s): c66db8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -1,12 +1,16 @@
1
  import gradio as gr
2
- from transformers import GPT2Tokenizer, GPT2LMHeadModel
3
  import torch
4
  from datasets import load_dataset
5
  import spaces
6
 
7
  # Load model once at startup
8
- model = GPT2LMHeadModel.from_pretrained("gpt2")
9
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
 
 
 
 
10
  tokenizer.pad_token = tokenizer.eos_token
11
 
12
  @spaces.GPU
@@ -33,6 +37,9 @@ def train_model():
33
  loss = outputs.loss
34
  loss.backward()
35
  optimizer.step()
 
 
 
36
 
37
  return f"✅ Training complete! Final Loss: {loss.item():.4f}"
38
  except Exception as e:
@@ -62,6 +69,4 @@ with gr.Blocks() as demo:
62
  prompt_input = gr.Textbox(label="Prompt", placeholder="Mačka je...")
63
  gen_btn = gr.Button("Generate")
64
  gen_output = gr.Textbox(label="Generated Text", interactive=False)
65
- gen_btn.click(generate_text, inputs=prompt_input, outputs=gen_output)
66
-
67
- demo.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
  from datasets import load_dataset
5
  import spaces
6
 
7
  # Load model once at startup
8
+ try:
9
+ model = AutoModelForCausalLM.from_pretrained("flamiry/first")
10
+ tokenizer = AutoTokenizer.from_pretrained("flamiry/first")
11
+ except:
12
+ model = AutoModelForCausalLM.from_pretrained("gpt2")
13
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
14
  tokenizer.pad_token = tokenizer.eos_token
15
 
16
  @spaces.GPU
 
37
  loss = outputs.loss
38
  loss.backward()
39
  optimizer.step()
40
+
41
+ model.push_to_hub("flamiry/first")
42
+ tokenizer.push_to_hub("flamiry/first")
43
 
44
  return f"✅ Training complete! Final Loss: {loss.item():.4f}"
45
  except Exception as e:
 
69
  prompt_input = gr.Textbox(label="Prompt", placeholder="Mačka je...")
70
  gen_btn = gr.Button("Generate")
71
  gen_output = gr.Textbox(label="Generated Text", interactive=False)
72
+ gen_btn.click(generate_text, inputs=prompt_input, outputs=gen_output)