Spaces:
Runtime error
Runtime error
# Make sure detectron2 is importable; if it is missing, install it at startup.
# NOTE(review): presumably required by the LayoutLMv2 model loaded below — confirm.
try:
    import detectron2  # noqa: F401  (imported only to probe availability)
except ImportError:
    # Only a missing package should trigger the install. A bare `except:` would
    # also swallow KeyboardInterrupt/SystemExit.
    import subprocess
    import sys

    # Install into the interpreter that runs this app (sys.executable), not
    # whatever `pip` is first on PATH. check_call raises CalledProcessError on
    # failure, instead of letting the app continue and fail later with a less
    # obvious error.
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "git+https://github.com/facebookresearch/detectron2.git",
        ]
    )
| import streamlit as st | |
| from PIL import Image | |
| import torch | |
| from transformers import LayoutLMv2ForSequenceClassification, LayoutLMv2Processor | |
# Load the fine-tuned classification model and its processor.
MODEL_ID = "Tornaid/LayoutLMv2_D3_Classifier"


# Streamlit re-executes this whole script on every widget interaction, so the
# loading is cached with st.cache_resource. The model and processor are then
# instantiated once per server process instead of on every click.
@st.cache_resource
def _load_model_and_processor():
    """Return the ``(model, processor)`` pair for ``MODEL_ID``."""
    model = LayoutLMv2ForSequenceClassification.from_pretrained(MODEL_ID)
    processor = LayoutLMv2Processor.from_pretrained(MODEL_ID)
    return model, processor


model_ft, processor_ft = _load_model_and_processor()
# Document-class name -> integer class index (0-18).
# NOTE(review): assumed to match the label order the fine-tuned model was
# trained with; verify against the model's config.
label2id = {
    'budget': 0, 'form': 1, 'file_folder': 2, 'invoice': 3, 'email': 4,
    'handwritten': 5, 'id_pieces': 6, 'advertisement': 7, 'carte postale': 8,
    'scientific_publication': 9, 'news_article': 10, 'scientific_report': 11,
    'resume': 12, 'letter': 13, 'presentation': 14, 'questionnaire': 15,
    'memo': 16, 'paye': 17, 'specification': 18,
}
# Inverse mapping (class index -> label name), used to decode the argmax of the
# model's logits. The loop variable is `idx` so it does not shadow builtin `id`.
id2label = {idx: label for label, idx in label2id.items()}
def predict_image_classification(image):
    """Classify a document image and return the predicted label name.

    Args:
        image: The image to classify. The Streamlit UI passes a PIL image
            already converted to RGB.

    Returns:
        The label from ``id2label`` for the highest-scoring class.
    """
    # Cap the text-token sequence at 512 tokens.
    # NOTE(review): presumably the model's maximum input length — confirm.
    inputs = processor_ft(image, return_tensors="pt", truncation=True, max_length=512)
    # Inference only: disabling autograd avoids building a computation graph,
    # which cuts memory use and latency.
    with torch.no_grad():
        outputs = model_ft(**inputs)
    # A single image is passed, so argmax over the last axis yields one index.
    prediction_index = outputs.logits.argmax(-1).item()
    return id2label[prediction_index]
# Streamlit interface: upload an image, preview it, classify it on demand.
st.title("Classification de documents")
uploaded_file = st.file_uploader("Choisissez un fichier image", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    try:
        # convert("RGB") forces a full decode, so truncated files fail here too.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for unreadable
        # data and OSError for truncated files. Report the problem instead of
        # crashing the page with a traceback.
        st.error("Impossible de lire ce fichier image.")
    else:
        # NOTE(review): `use_column_width` is deprecated in newer Streamlit
        # releases; revisit when upgrading.
        st.image(image, caption="Image chargée", use_column_width=True)
        if st.button("Classer"):
            label_pred = predict_image_classification(image)
            st.write(f"Prédiction: {label_pred}")