raahinaez commited on
Commit
7c1270e
·
verified ·
1 Parent(s): 783ec1e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -0
app.py CHANGED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import fitz # PyMuPDF
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+ from peft import PeftModel
6
+ import json
7
+
8
+ # -----------------------
9
+ # CONFIG
10
+ # -----------------------
11
+ MODEL_NAME = "prajjwal1/bert-tiny"
12
+ LORA_PATH = "./lora_adapter"
13
+ LABEL_FILE = "./lora_adapter/label_map.json"
14
+
15
+ # -----------------------
16
+ # LOAD LABEL MAP
17
+ # -----------------------
18
+ with open(LABEL_FILE, "r") as f:
19
+ label_map = json.load(f)
20
+
21
+ id2label = {int(k): v for k, v in label_map.items()}
22
+
23
+ # -----------------------
24
+ # LOAD MODEL
25
+ # -----------------------
26
+ @st.cache_resource
27
+ def load_model():
28
+ base_model = AutoModelForSequenceClassification.from_pretrained(
29
+ MODEL_NAME,
30
+ num_labels=len(id2label),
31
+ id2label=id2label
32
+ )
33
+ model = PeftModel.from_pretrained(base_model, LORA_PATH)
34
+ tokenizer = AutoTokenizer.from_pretrained(LORA_PATH)
35
+ model.eval()
36
+ return model, tokenizer
37
+
38
+ model, tokenizer = load_model()
39
+
40
+ # -----------------------
41
+ # PDF TEXT EXTRACTION
42
+ # -----------------------
43
+ def extract_text_from_pdf(uploaded_file):
44
+ doc = fitz.open(stream=uploaded_file.read(), filetype="pdf")
45
+ text = ""
46
+ for page in doc:
47
+ text += page.get_text()
48
+ return text.strip()
49
+
50
+ # -----------------------
51
+ # STREAMLIT UI
52
+ # -----------------------
53
+ st.set_page_config(page_title="Document Classifier", layout="centered")
54
+
55
+ st.title("📄 Document Classification App")
56
+ st.write("Upload a PDF and classify the document type.")
57
+
58
+ uploaded_file = st.file_uploader("Upload PDF", type=["pdf"])
59
+
60
+ if uploaded_file:
61
+ with st.spinner("Extracting text..."):
62
+ text = extract_text_from_pdf(uploaded_file)
63
+
64
+ if len(text) < 20:
65
+ st.error("Not enough text extracted from PDF.")
66
+ else:
67
+ with st.spinner("Classifying document..."):
68
+ inputs = tokenizer(
69
+ text,
70
+ return_tensors="pt",
71
+ truncation=True,
72
+ max_length=256
73
+ )
74
+
75
+ with torch.no_grad():
76
+ outputs = model(**inputs)
77
+
78
+ pred_id = torch.argmax(outputs.logits, dim=-1).item()
79
+ prediction = model.config.id2label[pred_id]
80
+
81
+ st.success(f"✅ Predicted Document Type: **{prediction.upper()}**")