Update app.py
Browse files
app.py
CHANGED
|
@@ -1,12 +1,16 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
from transformers import
|
| 3 |
import torch
|
| 4 |
from datasets import load_dataset
|
| 5 |
import spaces
|
| 6 |
|
| 7 |
# Load model once at startup
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
tokenizer.pad_token = tokenizer.eos_token
|
| 11 |
|
| 12 |
@spaces.GPU
|
|
@@ -33,6 +37,9 @@ def train_model():
|
|
| 33 |
loss = outputs.loss
|
| 34 |
loss.backward()
|
| 35 |
optimizer.step()
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
return f"✅ Training complete! Final Loss: {loss.item():.4f}"
|
| 38 |
except Exception as e:
|
|
@@ -62,6 +69,4 @@ with gr.Blocks() as demo:
|
|
| 62 |
prompt_input = gr.Textbox(label="Prompt", placeholder="Mačka je...")
|
| 63 |
gen_btn = gr.Button("Generate")
|
| 64 |
gen_output = gr.Textbox(label="Generated Text", interactive=False)
|
| 65 |
-
gen_btn.click(generate_text, inputs=prompt_input, outputs=gen_output)
|
| 66 |
-
|
| 67 |
-
demo.launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 3 |
import torch
|
| 4 |
from datasets import load_dataset
|
| 5 |
import spaces
|
| 6 |
|
| 7 |
# Load model once at startup
|
| 8 |
+
try:
|
| 9 |
+
model = AutoModelForCausalLM.from_pretrained("flamiry/first")
|
| 10 |
+
tokenizer = AutoTokenizer.from_pretrained("flamiry/first")
|
| 11 |
+
except:
|
| 12 |
+
model = AutoModelForCausalLM.from_pretrained("gpt2")
|
| 13 |
+
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
| 14 |
tokenizer.pad_token = tokenizer.eos_token
|
| 15 |
|
| 16 |
@spaces.GPU
|
|
|
|
| 37 |
loss = outputs.loss
|
| 38 |
loss.backward()
|
| 39 |
optimizer.step()
|
| 40 |
+
|
| 41 |
+
model.push_to_hub("flamiry/first")
|
| 42 |
+
tokenizer.push_to_hub("flamiry/first")
|
| 43 |
|
| 44 |
return f"✅ Training complete! Final Loss: {loss.item():.4f}"
|
| 45 |
except Exception as e:
|
|
|
|
| 69 |
prompt_input = gr.Textbox(label="Prompt", placeholder="Mačka je...")
|
| 70 |
gen_btn = gr.Button("Generate")
|
| 71 |
gen_output = gr.Textbox(label="Generated Text", interactive=False)
|
| 72 |
+
gen_btn.click(generate_text, inputs=prompt_input, outputs=gen_output)
|
|
|
|
|
|