Spaces:
Sleeping
Sleeping
| import torch | |
| import gradio as gr | |
| import multiprocessing | |
| import os | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer | |
| from peft import get_peft_model, LoraConfig, TaskType | |
| from datasets import load_dataset | |
# Device string handed to ``model.to(...)``; training runs on CPU only.
device = "cpu"

# Handle of the background training process; stays None until a run is started.
training_process = None

# File the training process writes its status to and the UI reads it from.
log_file = "training_status.log"
| # Logging function | |
def log_status(message):
    """Persist *message* as the current training status.

    The status file (module-level ``log_file``) always holds exactly one
    message: each call replaces the previous content.

    Args:
        message: Status text to store. It may contain non-ASCII characters
            (Persian text, emoji), so the file is always written as UTF-8
            instead of relying on the platform's default encoding.
    """
    # Write to a sibling temp file and swap it in atomically, so a reader in
    # another process never observes a truncated or half-written status.
    tmp_path = f"{log_file}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.write(message)
    os.replace(tmp_path, log_file)
| # Read training status | |
def read_status():
    """Return the most recently stored training status.

    Reads the module-level ``log_file`` as UTF-8 (the status text contains
    Persian characters and emoji, which a locale-dependent default encoding
    may not decode).

    Returns:
        The file's content, or a "waiting for training to start" placeholder
        when the status file does not exist yet.
    """
    # EAFP: opening directly avoids the check-then-open race with the
    # training process that creates the file.
    try:
        with open(log_file, "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        return "⏳ در انتظار شروع ترینینگ..."
| # Function to find the text column dynamically | |
def find_text_column(dataset):
    """Return the name of the first string-valued column of the train split.

    Only the first row of ``dataset["train"]`` is inspected; columns are
    checked in the row's own key order.

    Args:
        dataset: Mapping with a ``"train"`` entry that supports ``len()`` and
            integer indexing, where each row is a mapping of column name to
            value.

    Returns:
        The column name, or ``None`` when the train split is empty or its
        first row has no ``str`` value.

    Raises:
        KeyError: If the dataset has no ``"train"`` split.
    """
    train_split = dataset["train"]
    # An empty split has no row to inspect; report "no text column" instead
    # of failing with an IndexError.
    if len(train_split) == 0:
        return None

    first_row = train_split[0]
    for column, value in first_row.items():
        if isinstance(value, str):
            return column
    return None
| # Model training function | |
def train_model(dataset_url, model_url, epochs):
    """Fine-tune a causal language model with LoRA on CPU and report progress.

    Intended to run in a worker process (it is the target of
    ``start_training``). Every stage and any failure is reported through
    ``log_status``; exceptions are never propagated to the caller.

    Args:
        dataset_url: Dataset identifier passed to ``load_dataset``.
        model_url: Model identifier passed to ``from_pretrained``.
        epochs: Number of training epochs; coerced with ``int()`` and must be
            at least 1.

    Side effects:
        Writes checkpoints to ``./deepseek_lora_cpu`` and, after every epoch,
        the model to ``./deepseek_lora_finetuned_epoch_<n>``.
    """
    try:
        # Local imports keep this block self-contained; both come from
        # packages the module already depends on.
        import inspect

        from transformers import TrainerCallback

        total_epochs = int(epochs)
        if total_epochs < 1:
            raise ValueError("تعداد Epochs باید حداقل 1 باشد")

        log_status("🚀 در حال بارگیری مدل...")
        # SECURITY NOTE: trust_remote_code=True executes Python code shipped
        # inside the model repository, and model_url comes straight from the
        # UI. Only run this against repositories you trust.
        tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
        # padding="max_length" below needs a pad token; many causal-LM
        # tokenizers ship without one, so fall back to EOS.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            model_url, trust_remote_code=True, torch_dtype=torch.float32, device_map="cpu"
        )
        # With gradient checkpointing and a frozen base model, the embedding
        # outputs must require grad or the backward pass has nothing to
        # differentiate through.
        model.enable_input_require_grads()

        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=8,
            lora_alpha=32,
            lora_dropout=0.1,
            target_modules=["q_proj", "v_proj"],
        )
        model = get_peft_model(model, lora_config)
        model.to(device)

        dataset = load_dataset(dataset_url)

        # Automatically detect the text column to train on.
        text_column = find_text_column(dataset)
        if not text_column:
            log_status("❌ خطا: ستون متنی در دیتاست یافت نشد!")
            return

        def tokenize_function(examples):
            encoded = tokenizer(
                examples[text_column],
                truncation=True,
                padding="max_length",
                max_length=256,
                return_attention_mask=True,
            )
            # The Trainer needs labels to obtain a loss. Use the inputs as
            # labels and mask padded positions with -100 so they are ignored.
            encoded["labels"] = [
                [token if keep else -100 for token, keep in zip(ids, mask)]
                for ids, mask in zip(encoded["input_ids"], encoded["attention_mask"])
            ]
            return encoded

        tokenized_datasets = dataset.map(tokenize_function, batched=True)
        train_dataset = tokenized_datasets["train"]
        # Use the validation split for per-epoch evaluation when it exists.
        eval_dataset = (
            tokenized_datasets["validation"] if "validation" in tokenized_datasets else None
        )

        # The argument was renamed from `evaluation_strategy` to
        # `eval_strategy`; pick whichever the installed version accepts.
        init_params = inspect.signature(TrainingArguments.__init__).parameters
        eval_strategy_key = (
            "eval_strategy" if "eval_strategy" in init_params else "evaluation_strategy"
        )
        training_args = TrainingArguments(
            output_dir="./deepseek_lora_cpu",
            learning_rate=5e-4,
            per_device_train_batch_size=1,
            per_device_eval_batch_size=1,
            num_train_epochs=total_epochs,
            save_strategy="epoch",
            save_total_limit=2,
            logging_dir="./logs",
            logging_steps=10,
            fp16=False,
            gradient_checkpointing=True,
            optim="adamw_torch",
            report_to="none",
            **{eval_strategy_key: "epoch" if eval_dataset else "no"},
        )

        class _EpochStatusCallback(TrainerCallback):
            """Report progress and export the model once per epoch."""

            def on_epoch_begin(self, args, state, control, **kwargs):
                current = round(state.epoch or 0) + 1
                log_status(f"🔄 در حال اجرا: Epoch {current}/{total_epochs}...")

            def on_epoch_end(self, args, state, control, **kwargs):
                finished = round(state.epoch or 0)
                # `trainer` is bound below, before training starts.
                trainer.save_model(f"./deepseek_lora_finetuned_epoch_{finished}")

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            callbacks=[_EpochStatusCallback()],
        )

        log_status("🚀 ترینینگ شروع شد...")
        # One train() call covers all epochs. Calling it once per epoch with
        # resume_from_checkpoint=True fails on a fresh run (no checkpoint yet)
        # and would otherwise restart the whole schedule each time.
        trainer.train()
        log_status("✅ ترینینگ کامل شد!")
    except Exception as e:
        # Top-level boundary of the worker process: surface the failure in
        # the status file, which is the only channel back to the UI.
        log_status(f"❌ خطا: {str(e)}")
| # Start training in a separate process | |
def start_training(dataset_url, model_url, epochs):
    """Launch ``train_model`` in a background process unless one is running.

    Args:
        dataset_url: Dataset identifier entered in the UI.
        model_url: Base-model identifier entered in the UI.
        epochs: Number of epochs entered in the UI; forwarded unchanged.

    Returns:
        A user-facing status message: a warning when an input is blank or a
        training process is still alive, otherwise a "training started"
        confirmation.
    """
    global training_process

    # Reject blank inputs up front instead of spawning a process that is
    # certain to fail; also drop whitespace picked up when pasting.
    dataset_url = (dataset_url or "").strip()
    model_url = (model_url or "").strip()
    if not dataset_url or not model_url:
        return "⚠ لطفاً لینک دیتاست و مدل پایه را وارد کنید!"

    if training_process is not None and training_process.is_alive():
        return "⚠ ترینینگ در حال اجرا است!"

    training_process = multiprocessing.Process(
        target=train_model, args=(dataset_url, model_url, epochs)
    )
    training_process.start()
    return "🚀 ترینینگ شروع شد!"
| # Function to update the status | |
def update_status():
    """Fetch the current training status text via ``read_status``."""
    current_status = read_status()
    return current_status
| # Gradio UI | |
# Gradio UI: inputs for dataset / base model / epochs, a button that starts
# training, and a button that refreshes the status textbox.
# NOTE(review): the original indentation was lost; the three inputs are
# assumed to sit inside the Row and the remaining widgets below it.
with gr.Blocks() as app:
    gr.Markdown("# 🚀 AutoTrain DeepSeek R1 (CPU) - نمایش وضعیت لحظهای")
    with gr.Row():
        dataset_input = gr.Textbox(label="📂 لینک دیتاست (Hugging Face)")
        model_input = gr.Textbox(label="🤖 مدل پایه (Hugging Face)")
        epochs_input = gr.Number(label="🔄 تعداد Epochs", value=3)
    start_button = gr.Button("🚀 شروع ترینینگ")
    status_output = gr.Textbox(label="📢 وضعیت ترینینگ", interactive=False)
    start_button.click(
        start_training,
        inputs=[dataset_input, model_input, epochs_input],
        outputs=status_output,
    )
    status_button = gr.Button("🔄 بروزرسانی وضعیت")
    status_button.click(update_status, outputs=status_output)

# Launch only when run as a script. Under the spawn/forkserver start methods
# the training child process re-imports this module, and an unguarded
# launch() would try to start a second server there.
if __name__ == "__main__":
    app.launch()