# space / app.py
# flamiry's picture
# Update app.py
# 465015b verified
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from datasets import load_dataset
import spaces
from huggingface_hub import login
import os
from itertools import islice
# Authenticate against the Hugging Face Hub using the "hf_token" environment
# variable. login() falls back to an interactive prompt when it is given no
# token, which cannot be answered in a headless deployment, so only call it
# when a token is actually configured.
hf_token = os.environ.get("hf_token")
if hf_token:
    login(token=hf_token)
else:
    print('Warning: environment variable "hf_token" is not set; skipping Hub login.')

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("flamiry/first")
tokenizer = AutoTokenizer.from_pretrained("flamiry/first")
# Reuse the end-of-sequence token for padding so that batched tokenization
# with padding=True (used by train_model) has a pad token to work with.
tokenizer.pad_token = tokenizer.eos_token

# Move model to the GPU when one is available, otherwise stay on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
@spaces.GPU
def train_model(start, end):
    """Fine-tune the global model on a slice of the Slovak C4 stream.

    Examples ``start`` (inclusive) to ``end`` (exclusive) of the streamed
    "allenai/c4" / "sk" training split are tokenized as a single batch, the
    model takes two optimizer steps on that batch, and the updated model and
    tokenizer are pushed to the "flamiry/first" Hub repository.

    Args:
        start: Index of the first streamed example to use; anything ``int()``
            accepts (the UI passes textbox strings).
        end: Index one past the last example to use.

    Returns:
        A status string: the final loss on success, otherwise a message
        starting with "❌ Error:". Exceptions are reported, not raised.
    """
    try:
        # Parse inside the try block so that non-numeric input is reported in
        # the result box instead of escaping as an unhandled exception.
        start = int(start)
        end = int(end)
        if start < 0 or end <= start:
            return "❌ Error: Start must be >= 0 and End must be greater than Start"

        dataset = load_dataset("allenai/c4", "sk", split="train", streaming=True)
        slovak_texts = [example['text'] for example in islice(dataset, start, end)]
        if not slovak_texts:
            return "❌ Error: No examples found in the requested range"

        inputs = tokenizer(
            slovak_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )
        # Move inputs to device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # The pad token is the EOS token, so padding positions must be
        # excluded from the loss explicitly; -100 is the label value the
        # Transformers loss ignores.
        labels = inputs['input_ids'].clone()
        attention_mask = inputs.get('attention_mask')
        if attention_mask is not None:
            labels = labels.masked_fill(attention_mask == 0, -100)

        optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

        # from_pretrained() returns the model in eval mode; switch to train
        # mode so dropout is active, and always switch back afterwards so that
        # later generation runs deterministically.
        model.train()
        try:
            for _ in range(2):
                optimizer.zero_grad()
                outputs = model(**inputs, labels=labels)
                loss = outputs.loss
                loss.backward()
                optimizer.step()
        finally:
            model.eval()
        final_loss = loss.item()

        model.push_to_hub("flamiry/first")
        tokenizer.push_to_hub("flamiry/first")
        return f"✅ Training complete! Final Loss: {final_loss:.4f}"
    except Exception as e:
        return f"❌ Error: {str(e)}"
@spaces.GPU
def generate_text(prompt):
    """Generate a continuation of ``prompt`` with the global model.

    Args:
        prompt: Text to continue.

    Returns:
        The decoded output sequence (prompt plus continuation, special tokens
        removed), or a message starting with "❌ Error:" if the prompt is
        empty or generation fails. Exceptions are reported, not raised.
    """
    try:
        if not prompt:
            return "❌ Error: Prompt is empty"

        encoded = tokenizer(prompt, return_tensors="pt")
        input_ids = encoded["input_ids"].to(device)
        attention_mask = encoded["attention_mask"].to(device)

        # The pad token equals the EOS token, so generate() cannot infer the
        # attention mask from the ids; pass the mask and pad id explicitly.
        # NOTE: max_length counts the prompt tokens as well as the new ones.
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=50,
            pad_token_id=tokenizer.pad_token_id,
        )
        return tokenizer.decode(output[0], skip_special_tokens=True)
    except Exception as e:
        return f"❌ Error: {str(e)}"
# Two-tab Gradio interface: one tab triggers fine-tuning, the other generates text.
with gr.Blocks() as demo:
    gr.Markdown("# Slovak LLM Training")

    # Training tab: the start/end textboxes are passed to train_model.
    with gr.Tab("Train Model"):
        gr.Markdown("Click to train the model on Slovak data")
        start_input = gr.Textbox(label="Start")
        end_input = gr.Textbox(label="End")
        train_btn = gr.Button("Start Training", variant="primary")
        train_output = gr.Textbox(label="Result", interactive=False)
        train_btn.click(
            fn=train_model,
            inputs=[start_input, end_input],
            outputs=train_output,
        )

    # Generation tab: the prompt is passed to generate_text.
    with gr.Tab("Generate Text"):
        gr.Markdown("Generate Slovak text")
        prompt_input = gr.Textbox(label="Prompt")
        gen_btn = gr.Button("Generate")
        gen_output = gr.Textbox(label="Generated Text", interactive=False)
        gen_btn.click(
            fn=generate_text,
            inputs=prompt_input,
            outputs=gen_output,
        )

# Start the web server for the interface.
demo.launch()