Fix: Handle 'tweet' column in Arabic dataset correctly
Browse files
app.py
CHANGED
|
@@ -17,8 +17,12 @@ tokenizer = AutoTokenizer.from_pretrained('distilbert-base-multilingual-cased')
|
|
| 17 |
model = AutoModelForSequenceClassification.from_pretrained('distilbert-base-multilingual-cased', num_labels=3)
|
| 18 |
|
| 19 |
def preprocess_function(examples):
|
|
|
|
|
|
|
|
|
|
| 20 |
# Tokenize the Arabic text
|
| 21 |
-
encoding = tokenizer(examples[
|
|
|
|
| 22 |
# Map label to indices
|
| 23 |
if 'label' in examples:
|
| 24 |
encoding['labels'] = examples['label']
|
|
@@ -26,8 +30,8 @@ def preprocess_function(examples):
|
|
| 26 |
encoding['labels'] = examples['sentiment']
|
| 27 |
return encoding
|
| 28 |
|
| 29 |
-
# Preprocess the dataset
|
| 30 |
-
tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=['
|
| 31 |
|
| 32 |
def train_model(epochs, batch_size, learning_rate):
|
| 33 |
"""Fine-tune DistilBERT on Arabic sentiment dataset (Saudi dialect)"""
|
|
@@ -53,32 +57,32 @@ def train_model(epochs, batch_size, learning_rate):
|
|
| 53 |
# Start training
|
| 54 |
trainer.train()
|
| 55 |
|
| 56 |
-
return "\u270d✅
|
| 57 |
-
|
| 58 |
except Exception as e:
|
| 59 |
return f"❌ خطأ أثناء التدريب: {str(e)}"
|
| 60 |
|
| 61 |
# Create Gradio interface
|
| 62 |
with gr.Blocks(title="DistilBERT Arabic Sentiment Training") as demo:
|
| 63 |
gr.Markdown("""
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
|
| 75 |
with gr.Row():
|
| 76 |
with gr.Column():
|
| 77 |
gr.Markdown("### إعدادات التدريب")
|
| 78 |
-
epochs = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="
|
| 79 |
batch_size = gr.Slider(minimum=8, maximum=64, value=32, step=8, label="Batch Size")
|
| 80 |
learning_rate = gr.Slider(minimum=1e-5, maximum=1e-3, value=2e-5, step=1e-5, label="Learning Rate")
|
| 81 |
-
|
| 82 |
with gr.Column():
|
| 83 |
gr.Markdown("### حالة التدريب")
|
| 84 |
output_text = gr.Textbox(label="المخرجات", lines=10, interactive=False)
|
|
@@ -91,12 +95,12 @@ with gr.Blocks(title="DistilBERT Arabic Sentiment Training") as demo:
|
|
| 91 |
)
|
| 92 |
|
| 93 |
gr.Markdown("""
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
|
| 101 |
if __name__ == "__main__":
|
| 102 |
demo.launch()
|
|
|
|
| 17 |
model = AutoModelForSequenceClassification.from_pretrained('distilbert-base-multilingual-cased', num_labels=3)
|
| 18 |
|
| 19 |
def preprocess_function(examples):
|
| 20 |
+
# Check which column contains the text (tweet or text)
|
| 21 |
+
text_column = 'tweet' if 'tweet' in examples else 'text'
|
| 22 |
+
|
| 23 |
# Tokenize the Arabic text
|
| 24 |
+
encoding = tokenizer(examples[text_column], truncation=True, padding='max_length', max_length=128)
|
| 25 |
+
|
| 26 |
# Map label to indices
|
| 27 |
if 'label' in examples:
|
| 28 |
encoding['labels'] = examples['label']
|
|
|
|
| 30 |
encoding['labels'] = examples['sentiment']
|
| 31 |
return encoding
|
| 32 |
|
| 33 |
+
# Preprocess the dataset - drop all original columns, keeping only the tokenizer outputs and labels
|
| 34 |
+
tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=dataset['train'].column_names)
|
| 35 |
|
| 36 |
def train_model(epochs, batch_size, learning_rate):
|
| 37 |
"""Fine-tune DistilBERT on Arabic sentiment dataset (Saudi dialect)"""
|
|
|
|
| 57 |
# Start training
|
| 58 |
trainer.train()
|
| 59 |
|
| 60 |
+
return "\u270d✅ تم التدريب بنجاح!\n" + \
|
| 61 |
+
f"النموذج محفوظ في ./results\nمعدل التعلم: {learning_rate}\nعدد الحقب: {epochs}\nBatch Size: {batch_size}"
|
| 62 |
except Exception as e:
|
| 63 |
return f"❌ خطأ أثناء التدريب: {str(e)}"
|
| 64 |
|
| 65 |
# Create Gradio interface
|
| 66 |
with gr.Blocks(title="DistilBERT Arabic Sentiment Training") as demo:
|
| 67 |
gr.Markdown("""
|
| 68 |
+
# 🚀 تدريب نموذج DistilBERT العربي
|
| 69 |
+
|
| 70 |
+
ضبط نموذج **DistilBERT** على تحليل المشاعر باللغة العربية (اللهجة السعودية)
|
| 71 |
+
|
| 72 |
+
### معلومات النموذج:
|
| 73 |
+
- **النموذج الأساسي**: distilbert-base-multilingual-cased (67M معامل)
|
| 74 |
+
- **المهمة**: تصنيف النصوص (متعدد اللغات)
|
| 75 |
+
- **قاعدة البيانات**: arbml/Arabic_Sentiment_Twitter_Corpus (58.8k مثال)
|
| 76 |
+
- **اللغة**: العربية (اللهجة السعودية والخليجية)
|
| 77 |
+
""")
|
| 78 |
|
| 79 |
with gr.Row():
|
| 80 |
with gr.Column():
|
| 81 |
gr.Markdown("### إعدادات التدريب")
|
| 82 |
+
epochs = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="عدد الحقب (Epochs)")
|
| 83 |
batch_size = gr.Slider(minimum=8, maximum=64, value=32, step=8, label="Batch Size")
|
| 84 |
learning_rate = gr.Slider(minimum=1e-5, maximum=1e-3, value=2e-5, step=1e-5, label="Learning Rate")
|
| 85 |
+
|
| 86 |
with gr.Column():
|
| 87 |
gr.Markdown("### حالة التدريب")
|
| 88 |
output_text = gr.Textbox(label="المخرجات", lines=10, interactive=False)
|
|
|
|
| 95 |
)
|
| 96 |
|
| 97 |
gr.Markdown("""
|
| 98 |
+
### تفاصيل التدريب:
|
| 99 |
+
- **مرحلة البناء**: GPU مجاني (مباشر عبر Hugging Face Spaces)
|
| 100 |
+
- **الوقت المتوقع**: 5-10 دقائق (GPU) أو 15-20 دقيقة (CPU)
|
| 101 |
+
- **مخرجات النموذج**: محفوظ عند ./results
|
| 102 |
+
- **الاستخدام**: النصوص العربية فقط
|
| 103 |
+
""")
|
| 104 |
|
| 105 |
if __name__ == "__main__":
|
| 106 |
demo.launch()
|