cdelaunay commited on
Commit
953e037
·
verified ·
1 Parent(s): 38d458a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -0
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import torch
4
+ from transformers import LayoutLMv2ForSequenceClassification, LayoutLMv2Processor
5
+
6
+ # Chargement du modèle et du processeur
7
+ model_ft = LayoutLMv2ForSequenceClassification.from_pretrained("Tornaid/LayoutLMv2_D3_Classifier")
8
+ processor_ft = LayoutLMv2Processor.from_pretrained("Tornaid/LayoutLMv2_D3_Classifier")
9
+
10
+ label2id = {
11
+ 'budget': 0, 'form': 1, 'file_folder': 2, 'invoice': 3, 'email': 4,
12
+ 'handwritten': 5, 'id_pieces': 6, 'advertisement': 7, 'carte postale': 8,
13
+ 'scientific_publication': 9, 'news_article': 10, 'scientific_report': 11,
14
+ 'resume': 12, 'letter': 13, 'presentation': 14, 'questionnaire': 15,
15
+ 'memo': 16, 'paye': 17, 'specification': 18
16
+ }
17
+
18
+ id2label = {id: label for label, id in label2id.items()}
19
+
20
+ def predict_image_classification(image):
21
+ """Effectue la prédiction de classification sur l'image donnée."""
22
+ inputs = processor_ft(image, return_tensors="pt", truncation=True, max_length=512)
23
+ outputs = model_ft(**inputs)
24
+ prediction_index = outputs.logits.argmax(-1).item()
25
+ return id2label[prediction_index]
26
+
27
+ # Interface Streamlit
28
+ st.title("Classification de documents")
29
+
30
+ uploaded_file = st.file_uploader("Choisissez un fichier image", type=["jpg", "jpeg", "png"])
31
+ if uploaded_file is not None:
32
+ image = Image.open(uploaded_file).convert("RGB")
33
+ st.image(image, caption="Image chargée", use_column_width=True)
34
+ if st.button("Classer"):
35
+ label_pred = predict_image_classification(image)
36
+ st.write(f"Prédiction: {label_pred}")