| | import os
|
| | import pytesseract
|
| | from pdf2image import convert_from_path
|
| | from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
|
| | from datasets import Dataset
|
| | import torch
|
| | from flask import Flask, request, jsonify, render_template
|
| | from threading import Thread
|
| | import gradio as gr
|
| |
|
# Flask application object that the route decorators below attach to.
app = Flask(__name__)

# Directory scanned for training data: each entry is used as a class label
# and the files inside it as that class's documents.
input_folder = 'input'

# Directory the fine-tuned model and tokenizer are saved to and loaded from.
model_folder = 'model'
|
| |
|
| |
|
def pdf_to_text(file_path, lang='ara'):
    """Extract the text of a PDF by running OCR on each rendered page.

    Args:
        file_path: Path of the PDF file to read.
        lang: Tesseract language code handed to the OCR call. Defaults to
            ``'ara'``, the value that used to be hard-coded.

    Returns:
        The OCR output of every page, concatenated in the order the pages
        are returned by ``convert_from_path``. Empty string if there are no
        pages.
    """
    pages = convert_from_path(file_path)
    # Join once instead of growing a string with ``+=`` for every page.
    return ''.join(
        pytesseract.image_to_string(page, lang=lang) for page in pages
    )
|
| |
|
| |
|
def prepare_data():
    """Build the training dataset from the PDFs under ``input_folder``.

    Every sub-directory of ``input_folder`` is treated as one class. Each file
    inside it is converted to text with ``pdf_to_text`` and becomes one
    example.

    Returns:
        A ``(dataset, labels)`` tuple. ``dataset`` is a ``datasets.Dataset``
        with a ``text`` column and an integer ``label`` column. ``labels`` is
        the list of class names, where ``labels[i]`` names class id ``i``.
    """
    # Only directories are classes: a stray file in the input folder would
    # make the inner ``os.listdir`` raise NotADirectoryError.
    # The listing order is kept (not sorted) so class ids stay aligned with
    # any other code that derives the label list from the same listing.
    labels = [
        name for name in os.listdir(input_folder)
        if os.path.isdir(os.path.join(input_folder, name))
    ]

    data = {'text': [], 'label': []}
    for label_id, label in enumerate(labels):
        label_folder = os.path.join(input_folder, label)
        for file_name in os.listdir(label_folder):
            file_path = os.path.join(label_folder, file_name)
            data['text'].append(pdf_to_text(file_path))
            # Store the class id, not the folder name: the classifier is
            # trained on integer targets, and a string cannot be turned into
            # a label tensor.
            data['label'].append(label_id)

    return Dataset.from_dict(data), labels
|
| |
|
| |
|
def load_model():
    """Load the saved tokenizer and classifier from ``model_folder``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # The unused ``model_name`` local was dropped: both objects are loaded
    # from the saved folder, never from the base checkpoint.
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_folder)
    loaded_model = AutoModelForSequenceClassification.from_pretrained(model_folder)
    return loaded_tokenizer, loaded_model
|
| |
|
| |
|
def train_model():
    """Fine-tune the multilingual BERT classifier on the prepared dataset.

    Rebinds the module-level ``model`` and ``labels`` (and ``tokenizer`` when
    none has been loaded yet), trains for three epochs with an 80/20
    train/test split, and saves the model and tokenizer to ``model_folder``.

    Returns:
        A status message string.
    """
    global tokenizer, model, labels

    base_model = "bert-base-multilingual-cased"

    # The tokenizer is normally created by run_flask(). Load it here too so
    # a direct call does not fail with NameError.
    if globals().get('tokenizer') is None:
        tokenizer = AutoTokenizer.from_pretrained(base_model)

    dataset, labels = prepare_data()
    # train_test_split() needs at least one example on each side.
    if len(dataset) < 2:
        return "Not enough training documents found!"

    splits = dataset.train_test_split(test_size=0.2)
    tokenized_datasets = splits.map(
        lambda batch: tokenizer(batch['text'], padding="max_length", truncation=True),
        batched=True,
    )

    training_args = TrainingArguments(
        output_dir=model_folder,
        # NOTE(review): newer transformers releases name this argument
        # ``eval_strategy`` - confirm against the installed version.
        evaluation_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=3,
        weight_decay=0.01,
    )

    # Record the class names in the model config so they are saved together
    # with the weights.
    id2label = dict(enumerate(labels))
    label2id = {name: idx for idx, name in id2label.items()}
    model = AutoModelForSequenceClassification.from_pretrained(
        base_model,
        num_labels=len(labels),
        id2label=id2label,
        label2id=label2id,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets['train'],
        eval_dataset=tokenized_datasets['test'],
    )
    trainer.train()

    # The same object serves the inference endpoints afterwards, so make
    # sure dropout is switched off.
    model.eval()

    model.save_pretrained(model_folder)
    tokenizer.save_pretrained(model_folder)
    return "Model trained and saved!"
|
| |
|
| |
|
def classify_document(file_path):
    """OCR a PDF and predict its class with the currently loaded model.

    Args:
        file_path: Path of the PDF to classify.

    Returns:
        A ``(label, text)`` tuple: the predicted class name and the OCR text
        the prediction was made from.

    Raises:
        RuntimeError: If no model has been loaded or trained yet.
    """
    # Fail with a clear message instead of "'NoneType' is not callable".
    if model is None:
        raise RuntimeError("No trained model is available; train the model first.")

    text = pdf_to_text(file_path)
    encoded = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True)
    # Training may have moved the model to another device, so keep the
    # inputs on the same one.
    encoded = encoded.to(model.device)

    # Inference only: no gradient tracking needed.
    with torch.no_grad():
        logits = model(**encoded).logits

    predicted_id = torch.argmax(logits, dim=-1).item()
    return labels[predicted_id], text
|
| |
|
| |
|
@app.route('/')
def home():
    """Serve the landing page rendered from the ``index.html`` template."""
    page = render_template('index.html')
    return page
|
| |
|
@app.route('/train', methods=['POST'])
def train():
    """Run a training pass and return its status message as JSON."""
    return jsonify({'message': train_model()})
|
| |
|
@app.route('/classify', methods=['POST'])
def classify():
    """Classify a PDF uploaded under the ``file`` form field.

    The upload is stored in the ``uploads`` directory and handed to
    ``classify_document``.

    Returns:
        JSON with the predicted ``label`` and the OCR ``text``, or a JSON
        ``error`` with status 400 when no usable file was sent.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400

    file = request.files['file']
    if not file.filename:
        return jsonify({'error': 'No file selected'}), 400

    # The filename is client-controlled: keep only its last path component
    # so a name like "../../x" cannot escape the upload directory.
    safe_name = os.path.basename(file.filename.replace('\\', '/'))
    if safe_name in ('', '.', '..'):
        return jsonify({'error': 'Invalid file name'}), 400

    upload_dir = 'uploads'
    # Saving fails if the target directory does not exist yet.
    os.makedirs(upload_dir, exist_ok=True)
    file_path = os.path.join(upload_dir, safe_name)
    file.save(file_path)

    label, text = classify_document(file_path)
    return jsonify({'label': label, 'text': text})
|
| |
|
def run_flask():
    """Initialise the shared model state and start the Flask server.

    Sets the module-level ``tokenizer``, ``model`` and ``labels``. When
    ``model_folder`` exists, the saved model is loaded. Otherwise only the
    base tokenizer is prepared and ``model`` stays ``None`` until training
    runs. Then serves the app on port 5000 (blocking call).
    """
    # Declared once at the top; it applies to the whole function either way.
    global tokenizer, model, labels

    if os.path.exists(model_folder):
        tokenizer, model = load_model()
        if os.path.isdir(input_folder):
            # Only directories are classes; stray files are not labels.
            labels = [
                name for name in os.listdir(input_folder)
                if os.path.isdir(os.path.join(input_folder, name))
            ]
        else:
            # Without the training folder, fall back to the class names
            # stored in the model config rather than crashing at startup.
            labels = [
                model.config.id2label[idx]
                for idx in range(model.config.num_labels)
            ]
    else:
        tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
        model = None
        labels = []

    app.run(port=5000)
|
| |
|
| |
|
def run_gradio():
    """Launch a Gradio text-in/text-out UI for the classifier.

    The interface listens on ``0.0.0.0:7860`` and uses the module-level
    ``tokenizer``, ``model`` and ``labels`` at call time.
    """
    def classify(text):
        """Return the predicted class name for the given text."""
        # Fail with a clear message instead of "'NoneType' is not callable".
        if model is None:
            raise RuntimeError("No trained model is available; train the model first.")

        encoded = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True)
        # Training may have moved the model to another device, so keep the
        # inputs on the same one.
        encoded = encoded.to(model.device)

        # Inference only: no gradient tracking needed.
        with torch.no_grad():
            logits = model(**encoded).logits

        return labels[torch.argmax(logits, dim=-1).item()]

    gr.Interface(fn=classify, inputs="text", outputs="text").launch(server_name="0.0.0.0", server_port=7860)
|
| |
|
if __name__ == '__main__':
    # Start the Flask API first, then the Gradio UI, each in its own thread.
    for entry_point in (run_flask, run_gradio):
        Thread(target=entry_point).start()
|
| |
|