HedronCreeper committed on
Commit
92923aa
·
verified ·
1 Parent(s): 4fbab13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -68
app.py CHANGED
@@ -1,80 +1,82 @@
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
3
  from datasets import load_dataset
4
  import torch
5
  import os
6
 
7
- MODEL_DIR = "./model"
8
-
9
- def train_model():
10
- model_name = "distilgpt2"
11
-
12
- tokenizer = AutoTokenizer.from_pretrained(model_name)
13
- model = AutoModelForCausalLM.from_pretrained(model_name)
14
-
15
- dataset = load_dataset("text", data_files={"train": "data.txt"})
16
-
17
- def tokenize(example):
18
- return tokenizer(example["text"], truncation=True, padding="max_length", max_length=128)
19
-
20
- tokenized = dataset.map(tokenize, batched=True)
21
-
22
- training_args = TrainingArguments(
23
- output_dir=MODEL_DIR,
24
- per_device_train_batch_size=2,
25
- num_train_epochs=2,
26
- logging_steps=10,
27
- save_steps=50
28
- )
29
-
30
- trainer = Trainer(
31
- model=model,
32
- args=training_args,
33
- train_dataset=tokenized["train"],
34
  )
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  trainer.train()
 
37
 
38
- model.save_pretrained(MODEL_DIR)
39
- tokenizer.save_pretrained(MODEL_DIR)
40
-
41
- return "Training complete!"
42
-
43
- def load_model():
44
- tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
45
- model = AutoModelForCausalLM.from_pretrained(MODEL_DIR)
46
- return tokenizer, model
47
-
48
- def chat(user_input):
49
- if not os.path.exists(MODEL_DIR):
50
- return "Model not trained yet. Click Train first."
51
-
52
- tokenizer, model = load_model()
53
-
54
- prompt = f"User: {user_input}\nAssistant:"
55
- inputs = tokenizer(prompt, return_tensors="pt")
56
-
57
- outputs = model.generate(
58
- **inputs,
59
- max_length=100,
60
- do_sample=True,
61
- temperature=0.7
62
- )
63
-
64
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
65
 
66
  with gr.Blocks() as demo:
67
- gr.Markdown("# 🤖 My First AI")
68
-
69
- train_btn = gr.Button("Train Model")
70
- output = gr.Textbox()
71
-
72
- train_btn.click(train_model, outputs=output)
73
-
74
- user_input = gr.Textbox(label="Your message")
75
- chat_output = gr.Textbox(label="AI Response")
76
-
77
- send_btn = gr.Button("Send")
78
- send_btn.click(chat, inputs=user_input, outputs=chat_output)
79
-
80
- demo.launch()
 
1
+ # app.py
2
  import gradio as gr
3
  from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
4
  from datasets import load_dataset
5
  import torch
6
  import os
7
 
8
+ # -----------------------------
9
+ # 1️⃣ Model setup
10
+ # -----------------------------
11
+ MODEL_DIR = "model"
12
+ MODEL_NAME = "sshleifer/tiny-gpt2" # tiny GPT-2, CPU-friendly
13
+
14
+ # Load tokenizer & model
15
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
16
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
17
+
18
+ # Fix padding issue
19
+ tokenizer.pad_token = tokenizer.eos_token
20
+
21
+ # -----------------------------
22
+ # 2️⃣ Dataset setup
23
+ # -----------------------------
24
+ # Make sure you have 'data.txt' in the same folder as app.py
25
+ dataset = load_dataset("text", data_files="data.txt")
26
+
27
+ def tokenize(example):
28
+ return tokenizer(
29
+ example["text"],
30
+ truncation=True,
31
+ padding="max_length",
32
+ max_length=64 # small for CPU
 
 
33
  )
34
 
35
+ tokenized_dataset = dataset.map(tokenize, batched=True)
36
+
37
+ # -----------------------------
38
+ # 3️⃣ Training setup
39
+ # -----------------------------
40
+ training_args = TrainingArguments(
41
+ output_dir=MODEL_DIR,
42
+ overwrite_output_dir=True,
43
+ per_device_train_batch_size=1, # CPU-friendly
44
+ num_train_epochs=1, # short test run
45
+ logging_steps=5,
46
+ save_steps=20,
47
+ save_total_limit=1
48
+ )
49
+
50
+ trainer = Trainer(
51
+ model=model,
52
+ args=training_args,
53
+ train_dataset=tokenized_dataset["train"]
54
+ )
55
+
56
+ # -----------------------------
57
+ # 4️⃣ Gradio interface
58
+ # -----------------------------
59
+ def train_model():
60
  trainer.train()
61
+ return "✅ Training complete! Model saved to /model"
62
 
63
+ def generate_text(prompt):
64
+ inputs = tokenizer(prompt, return_tensors="pt", padding=True)
65
+ output = model.generate(**inputs, max_length=64, pad_token_id=tokenizer.eos_token_id)
66
+ return tokenizer.decode(output[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  with gr.Blocks() as demo:
69
+ gr.Markdown("# Tiny AI Training Demo")
70
+
71
+ with gr.Tab("Train Model"):
72
+ train_button = gr.Button("Train")
73
+ train_output = gr.Textbox(label="Logs")
74
+ train_button.click(train_model, outputs=train_output)
75
+
76
+ with gr.Tab("Generate Text"):
77
+ prompt_input = gr.Textbox(label="Prompt")
78
+ generate_button = gr.Button("Generate")
79
+ generate_output = gr.Textbox(label="Output")
80
+ generate_button.click(generate_text, inputs=prompt_input, outputs=generate_output)
81
+
82
+ demo.launch(share=True)