arpit-gour02 committed on
Commit
3f0bea0
·
unverified ·
1 Parent(s): b59b8e6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from torchvision import models, transforms
5
+
6
+ # --- 1. CONFIGURATION ---
7
# Hub repository and checkpoint file that hold the trained weights.
REPO_ID = "arpit-gour02/document-classification"
MODEL_FILENAME = "resnet50_epoch_5.pth"

# Output labels. The prediction function pairs these positionally with the
# model's outputs, so the order here must match the order used in training.
class_names = [
    "letter",
    "form",
    "email",
    "handwritten",
    "advertisement",
    "scientific report",
    "scientific publication",
    "specification",
    "file folder",
    "news article",
    "budget",
    "invoice",
    "presentation",
    "questionnaire",
    "resume",
    "memo",
]
15
+
16
+ # --- 2. LOAD MODEL FROM HUB ---
17
def load_model_from_hub():
    """Download the trained checkpoint from the Hub and build the classifier.

    Fetches ``MODEL_FILENAME`` from ``REPO_ID`` with ``hf_hub_download``,
    creates a ResNet-50 with one output unit per entry in ``class_names``,
    loads the weights onto the CPU and switches the network to eval mode.

    Returns:
        The ResNet-50 model with the downloaded weights, in ``eval()`` mode.
    """
    print(f"Downloading {MODEL_FILENAME} from {REPO_ID}...")
    model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)

    # Size the head from the label list so the two cannot drift apart; a
    # mismatch with the checkpoint then fails loudly in load_state_dict.
    model = models.resnet50(num_classes=len(class_names))

    # The checkpoint is a pickle fetched over the network. weights_only=True
    # restricts unpickling to tensors and plain containers, so a tampered
    # file cannot execute arbitrary code on load (requires torch >= 1.13).
    checkpoint = torch.load(model_path, map_location="cpu", weights_only=True)

    # Accept either a bare state dict or a training checkpoint that wraps
    # the weights under a "state_dict" key.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    model.eval()
    return model
33
+
34
# Load once at import time so every prediction reuses the same network.
model = load_model_from_hub()

# --- 3. PREPROCESSING ---
# Resize to 224x224, convert to a tensor, then normalize each of the three
# channels with a fixed mean/std.
# NOTE(review): these look like the usual ImageNet statistics; confirm they
# match what was used during training.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
42
+
43
+ # --- 4. PREDICTION FUNCTION ---
44
def predict(image):
    """Classify one document image.

    Args:
        image: PIL image handed over by the Gradio input component, or
            ``None`` when no image was provided.

    Returns:
        A dict mapping each label in ``class_names`` to its softmax
        probability, or ``None`` when ``image`` is ``None``.
    """
    if image is None:
        return None

    # The module-level transform normalizes with three-channel statistics,
    # so force RGB; grayscale or RGBA input would otherwise fail there.
    batch = transform(image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)

    # Single-image batch: take row 0 and turn logits into probabilities.
    probabilities = torch.softmax(logits[0], dim=0).tolist()

    return dict(zip(class_names, probabilities))
53
+
54
+ # --- 5. LAUNCH INTERFACE ---
55
# Offer the sample image only when it actually ships next to the app;
# pointing Gradio at a missing example file can break start-up.
_EXAMPLE_IMAGE = "sample_invoice.jpg"
_examples = [[_EXAMPLE_IMAGE]] if os.path.isfile(_EXAMPLE_IMAGE) else None

# Image in, top-3 label probabilities out.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    title="Document Classifier (ResNet50)",
    description=f"Classifies documents into 16 categories using a ResNet50 model hosted at <a href='https://huggingface.co/{REPO_ID}'>{REPO_ID}</a>.",
    examples=_examples,
)

if __name__ == "__main__":
    interface.launch()