ANISA09 commited on
Commit
e3158c7
·
verified ·
1 Parent(s): b83bc11

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -81
app.py CHANGED
@@ -1,98 +1,70 @@
 
 
1
  import gradio as gr
2
- from transformers import AutoFeatureExtractor, AutoModelForImageClassification
3
- import torch
4
- from torchvision import transforms
5
- from PIL import Image
6
- import pytesseract
7
- import cv2
8
- import numpy as np
9
 
10
- # -------------------------------------------------------------
11
- # Setup OCR
12
- # -------------------------------------------------------------
13
- pytesseract.pytesseract.tesseract_cmd = "/usr/bin/tesseract"
14
 
15
- # -------------------------------------------------------------
16
- # Load Pretrained Vision Model
17
- # -------------------------------------------------------------
18
- # Using ResNet18 for demonstration
19
- from torchvision.models import resnet18
20
 
21
- model = resnet18(weights="IMAGENET1K_V1")
22
- model.eval()
23
 
24
- # Define transform for the model
25
- preprocess = transforms.Compose([
26
- transforms.Resize((224, 224)),
27
- transforms.ToTensor(),
28
- transforms.Normalize(
29
- mean=[0.485, 0.456, 0.406],
30
- std=[0.229, 0.224, 0.225]
31
- )
32
- ])
33
 
34
- # -------------------------------------------------------------
35
- # Certificate Verification Function
36
- # -------------------------------------------------------------
37
- REQUIRED_KEYWORDS = ["certificate", "proudly presented", "position", "organized by", "date"]
 
 
 
 
38
 
39
- def verify_certificate(image):
40
- # Ensure PIL Image
41
- if not isinstance(image, Image.Image):
42
- image = Image.fromarray(image)
43
- image = image.convert("RGB")
44
 
45
- # ------------------------------
46
- # 1️⃣ Model Prediction (generic)
47
- # ------------------------------
48
- input_tensor = preprocess(image).unsqueeze(0) # add batch dim
49
- with torch.no_grad():
50
- outputs = model(input_tensor)
51
- probs = torch.nn.functional.softmax(outputs[0], dim=0)
52
- top_prob, top_catid = torch.topk(probs, 1)
53
- model_confidence = float(top_prob.item())
54
- model_label = str(top_catid.item()) # generic label index
55
 
56
- # ------------------------------
57
- # 2️⃣ OCR Extraction
58
- # ------------------------------
59
- img_np = np.array(image)
60
- gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
61
- text = pytesseract.image_to_string(gray)
62
 
63
- # ------------------------------
64
- # 3️⃣ Heuristic Text Scoring
65
- # ------------------------------
66
- keyword_matches = sum([1 for kw in REQUIRED_KEYWORDS if kw.lower() in text.lower()])
67
- ocr_score = keyword_matches / len(REQUIRED_KEYWORDS)
68
 
69
- # ------------------------------
70
- # 4️⃣ Combine Model + OCR
71
- # ------------------------------
72
- combined_confidence = round((model_confidence + ocr_score) / 2, 4)
73
 
74
- # ------------------------------
75
- # 5️⃣ Return Result
76
- # ------------------------------
77
- result = {
78
- "model_label": model_label,
79
- "model_confidence": round(model_confidence, 4),
80
- "ocr_score": round(ocr_score, 4),
81
- "combined_confidence": combined_confidence,
82
- "text_preview": text[:300]
83
- }
84
- return result
85
 
86
- # -------------------------------------------------------------
87
- # Gradio Interface
88
- # -------------------------------------------------------------
89
  demo = gr.Interface(
90
- fn=verify_certificate,
91
- inputs=gr.Image(type="numpy", label="Upload Certificate Image"),
92
- outputs=gr.JSON(label="Verification Result"),
93
- title="Certificate Verification AI 🧠",
94
- description="Uploads a certificate image, checks for authenticity using a vision model and OCR keyword heuristics."
 
 
 
 
 
 
 
95
  )
96
 
97
  if __name__ == "__main__":
98
- demo.launch()
 
1
+ import os
2
+ import uuid
3
  import gradio as gr
4
+ from inference_sdk import InferenceHTTPClient
 
 
 
 
 
 
5
 
6
+ # Ensure uploads folder exists
7
+ UPLOAD_DIR = "uploads"
8
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
 
9
 
10
+ # Initialize Roboflow inference client
11
+ CLIENT = InferenceHTTPClient(
12
+ api_url="https://serverless.roboflow.com",
13
+ api_key="i22FWkifZzD236Hhg56U" # ⚠️ Replace with your actual API key if different
14
+ )
15
 
16
+ # Model ID (Project Slug + Version)
17
+ MODEL_ID = "detecting-fake-certificates-bj1x6/3"
18
 
 
 
 
 
 
 
 
 
 
19
 
20
+ def analyze_certificate(image):
21
+ """Save uploaded image locally and send it to Roboflow for inference."""
22
+ try:
23
+ # Save image with unique filename
24
+ filename = f"{uuid.uuid4()}.jpg"
25
+ save_path = os.path.join(UPLOAD_DIR, filename)
26
+ image.save(save_path)
27
+ print(f"[INFO] Image saved at {save_path}")
28
 
29
+ # Perform inference using Roboflow model
30
+ result = CLIENT.infer(save_path, model_id=MODEL_ID)
31
+ print("[INFO] Inference result:", result)
 
 
32
 
33
+ # Parse predictions
34
+ predictions = result.get("predictions", [])
35
+ if not predictions:
36
+ return "⚠️ No objects detected — possibly a valid certificate.", save_path
 
 
 
 
 
 
37
 
38
+ # Build readable output
39
+ output_lines = []
40
+ for pred in predictions:
41
+ cls = pred.get("class", "unknown")
42
+ conf = round(pred.get("confidence", 0) * 100, 2)
43
+ output_lines.append(f"- {cls} ({conf}% confidence)")
44
 
45
+ output_text = "✅ **Detections:**\n" + "\n".join(output_lines)
46
+ return output_text, save_path
 
 
 
47
 
48
+ except Exception as e:
49
+ print("[ERROR]", e)
50
+ return f"❌ Error during inference: {e}", None
 
51
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ # Gradio interface
 
 
54
  demo = gr.Interface(
55
+ fn=analyze_certificate,
56
+ inputs=gr.Image(type="pil", label="Upload a Certificate"),
57
+ outputs=[
58
+ gr.Textbox(label="Inference Result"),
59
+ gr.Image(label="Uploaded Image")
60
+ ],
61
+ title="Fake Certificate Detector 🧠",
62
+ description=(
63
+ "Upload a certificate image — this app will use a trained Roboflow model "
64
+ "(`detecting-fake-certificates-bj1x6/3`) to detect possible signs of forgery."
65
+ ),
66
+ allow_flagging="never"
67
  )
68
 
69
  if __name__ == "__main__":
70
+ demo.launch(server_name="0.0.0.0", server_port=7860)