File size: 8,322 Bytes
c2fa27a
 
f3b82b8
58f1841
 
 
 
 
 
a5fe846
 
 
 
58f1841
 
 
 
 
 
 
 
 
a5fe846
58f1841
a5fe846
58f1841
a5fe846
58f1841
 
 
 
 
 
 
 
 
 
 
 
 
 
f3b82b8
86c2ef8
a5fe846
f3b82b8
 
 
 
a5fe846
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58f1841
c2fa27a
58f1841
c2fa27a
86c2ef8
58f1841
86c2ef8
58f1841
86c2ef8
58f1841
 
 
 
86c2ef8
58f1841
86c2ef8
c2fa27a
 
f3b82b8
a5fe846
58f1841
f3b82b8
 
 
 
 
 
 
 
58f1841
 
 
 
f3b82b8
58f1841
 
f3b82b8
58f1841
 
86c2ef8
58f1841
 
 
 
 
 
 
86c2ef8
58f1841
 
 
 
a5fe846
58f1841
 
 
 
 
 
 
 
 
 
86c2ef8
c2fa27a
 
86c2ef8
58f1841
c2fa27a
58f1841
c2fa27a
 
 
86c2ef8
c2fa27a
a5fe846
58f1841
 
 
c2fa27a
 
 
58f1841
c2fa27a
58f1841
 
c2fa27a
 
58f1841
 
 
 
c2fa27a
 
 
 
 
 
58f1841
86c2ef8
c2fa27a
58f1841
c2fa27a
 
 
a5fe846
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
import base64
import io
import os
import re
import secrets
from datetime import datetime

import gradio as gr
import torch
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException, Security
from fastapi.security import APIKeyHeader
from PIL import Image
from simple_salesforce import Salesforce
from transformers import ViTForImageClassification, ViTImageProcessor

# Pull settings from the process environment (load_dotenv() is called first
# so values from a .env file, if any, are available to the lookups below).
load_dotenv()

HF_API_KEY = os.environ.get("HF_API_KEY")

# Salesforce login settings
SF_USERNAME = os.environ.get("SF_USERNAME")
SF_PASSWORD = os.environ.get("SF_PASSWORD")
SF_SECURITY_TOKEN = os.environ.get("SF_SECURITY_TOKEN")
SF_CONSUMER_KEY = os.environ.get("SF_CONSUMER_KEY")
SF_CONSUMER_SECRET = os.environ.get("SF_CONSUMER_SECRET")

# Key expected in the X-API-Key header of API requests (set in Space Secrets).
# NOTE(review): when API_KEY is unset the placeholder below becomes the
# accepted key -- confirm that this fallback is intended.
API_KEY = os.environ.get("API_KEY", "your-api-key-here")

# Refuse to start unless every Salesforce setting is present and non-empty.
_salesforce_settings = (
    SF_USERNAME,
    SF_PASSWORD,
    SF_SECURITY_TOKEN,
    SF_CONSUMER_KEY,
    SF_CONSUMER_SECRET,
)
if any(not setting for setting in _salesforce_settings):
    raise ValueError("Missing Salesforce credentials. Set SF_USERNAME, SF_PASSWORD, SF_SECURITY_TOKEN, SF_CONSUMER_KEY, and SF_CONSUMER_SECRET in environment variables.")

# Open the Salesforce session used by the Gradio workflow.
_salesforce_login = {
    "username": SF_USERNAME,
    "password": SF_PASSWORD,
    "security_token": SF_SECURITY_TOKEN,
    "consumer_key": SF_CONSUMER_KEY,
    "consumer_secret": SF_CONSUMER_SECRET,
    "domain": 'login',  # Use 'test' for sandbox
}
try:
    sf = Salesforce(**_salesforce_login)
except Exception as exc:
    # Report the failure on stdout, then let the original error abort start-up.
    print(f"Salesforce connection failed: {exc!s}")
    raise

# Fetch the pre-trained ViT checkpoint together with its image processor.
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)
# eval() puts the module in evaluation mode and returns the module itself,
# so the call can be chained onto the load.
model = ViTForImageClassification.from_pretrained(model_name).eval()

# FastAPI app for API endpoint
app = FastAPI()

# API Key authentication
api_key_header = APIKeyHeader(name="X-API-Key")


async def verify_api_key(api_key: str = Security(api_key_header)):
    """Check the ``X-API-Key`` request header against the configured key.

    Args:
        api_key: Header value extracted by ``api_key_header``.

    Returns:
        The supplied key, unchanged, when it matches ``API_KEY``.

    Raises:
        HTTPException: 401 when the supplied key does not match.
    """
    # Compare as bytes in constant time: a plain `!=` can leak how many
    # leading characters matched through response timing, and
    # compare_digest() only accepts non-ASCII input as bytes.
    supplied = (api_key or "").encode("utf-8")
    expected = API_KEY.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Invalid API Key")
    return api_key

@app.post("/predict-milestone")
async def predict_milestone(payload: dict, api_key: str = Depends(verify_api_key)):
    """Classify a base64-encoded photo and return the predicted milestone.

    Args:
        payload: JSON request body. Must contain an ``image`` key holding the
            base64-encoded picture, optionally prefixed with a ``data:image``
            data-URI header.
        api_key: Result of the ``verify_api_key`` dependency; not used in the
            body, it only enforces authentication.

    Returns:
        A dict with ``milestone`` (the top label looked up in
        ``model.config.id2label``) and ``percent_complete`` (currently the
        fixed placeholder 75).

    Raises:
        HTTPException: 400 when the image is missing, not a string, not valid
            base64, larger than 20MB, unreadable, or not JPG/PNG; 500 for any
            other failure while preprocessing or running inference.
    """
    try:
        # Validate payload
        if "image" not in payload:
            raise HTTPException(status_code=400, detail="Image field is required")
        image_data = payload["image"]
        if not isinstance(image_data, str):
            raise HTTPException(status_code=400, detail="Image field must be a base64-encoded string")

        # Decode base64 image
        if image_data.startswith("data:image"):
            # Keep everything after the first comma; partition() yields an
            # empty string (instead of raising) when the comma is missing.
            image_data = image_data.partition(",")[2]
        try:
            img_bytes = base64.b64decode(image_data)
        except ValueError as exc:  # binascii.Error is a ValueError subclass
            raise HTTPException(status_code=400, detail="Image is not valid base64") from exc

        # Validate image size (max 20MB per TSD) before spending time decoding it
        img_bytes_size = len(img_bytes) / (1024 * 1024)
        if img_bytes_size > 20:
            raise HTTPException(status_code=400, detail="Image size exceeds 20MB")

        try:
            raw_img = Image.open(io.BytesIO(img_bytes))
        except OSError as exc:  # includes PIL's UnidentifiedImageError
            raise HTTPException(status_code=400, detail="Image could not be decoded") from exc

        # Validate image type. The format must be read from the freshly opened
        # image: the copy returned by convert() no longer carries it.
        image_format = (raw_img.format or "").lower()
        if image_format not in ("jpeg", "png"):
            raise HTTPException(status_code=400, detail="Only JPG/PNG images are supported")
        img = raw_img.convert("RGB")

        # Preprocess image
        max_size = 1024  # Optimize for performance
        img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
        inputs = processor(images=img, return_tensors="pt")

        # Run inference (within 5 seconds per TSD)
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
            probabilities = torch.softmax(logits, dim=1)

        # Get top prediction
        predicted_idx = torch.argmax(probabilities, dim=1).item()
        milestone = model.config.id2label.get(predicted_idx, "Unknown Milestone")

        # Mock percent_complete (since model doesn't provide it)
        percent_complete = 75  # Adjust based on milestone or train a model to predict this

        return {
            "milestone": milestone,
            "percent_complete": percent_complete
        }

    except HTTPException:
        # Deliberate client errors raised above must keep their own status
        # code instead of being rewrapped as a 500 by the handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e

def process_image(image, project_name):
    """Classify an uploaded site photo and record the result in Salesforce.

    Args:
        image: Filesystem path of the uploaded picture, or ``None`` when
            nothing was uploaded.
        project_name: Name of the ``Construction_Project__c`` record to
            update, or to create when no record with that name exists.

    Returns:
        A 4-tuple ``(result message, upload status, milestone, top
        predictions)``. Failures are reported through the message and status
        entries rather than raised.
    """
    try:
        # Validate inputs
        if image is None:
            return "Error: Please upload an image to proceed.", "Pending", "", ""
        if not project_name:
            return "Error: Please enter a project name to proceed.", "Pending", "", ""
        if not re.match(r'^[a-zA-Z0-9\s-]+$', project_name):
            return "Error: Project name must be alphanumeric (letters, numbers, spaces, or hyphens).", "Pending", "", ""

        # Validate image size and type
        image_size_mb = os.path.getsize(image) / (1024 * 1024)
        if image_size_mb > 20:
            return "Error: Image size exceeds 20MB.", "Failure", "", ""
        if not image.lower().endswith(('.jpg', '.jpeg', '.png')):
            return "Error: Only JPG/PNG images are supported.", "Failure", "", ""

        # Preprocess image. The context manager closes the underlying file as
        # soon as the RGB copy exists instead of leaving the handle open.
        with Image.open(image) as source_img:
            img = source_img.convert("RGB")
        max_size = 1024
        img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
        inputs = processor(images=img, return_tensors="pt")

        # Run inference
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
            probabilities = torch.softmax(logits, dim=1)

        # Get top predictions
        top_probs, top_indices = torch.topk(probabilities, 3, dim=1)
        top_probs = top_probs[0].tolist()
        top_indices = top_indices[0].tolist()

        # Map indices to labels
        id2label = model.config.id2label
        milestone = id2label.get(top_indices[0], "Unknown Milestone")

        # Format top predictions, one "label: probability" line each
        prediction_details = "\n".join(
            f"{id2label.get(idx, 'Unknown Milestone')}: {prob:.2f}"
            for idx, prob in zip(top_indices, top_probs)
        )

        # Update Salesforce record
        record = {
            "Name": project_name,
            "Current_Milestone__c": milestone,
            "Last_Updated_On__c": datetime.now().isoformat(),
            "Upload_Status__c": "Success",
            "Comments__c": f"AI Prediction: {milestone}",
            "Version__c": 1
        }

        try:
            # SOQL string literals escape backslashes and single quotes with a
            # backslash (doubling the quote is SQL syntax, not SOQL). The name
            # pattern above already rejects both characters, so this is
            # defence in depth.
            soql_name = project_name.replace("\\", "\\\\").replace("'", "\\'")
            query = f"SELECT Id, Version__c FROM Construction_Project__c WHERE Name = '{soql_name}'"
            result = sf.query(query)
            if result["totalSize"] > 0:
                existing = result["records"][0]
                # `.get(key, 0)` does not help when the key is present with a
                # None value (e.g. an unset field), so fall back with `or`.
                current_version = existing.get("Version__c") or 0
                record["Version__c"] = current_version + 1
                sf.Construction_Project__c.update(existing["Id"], record)
            else:
                sf.Construction_Project__c.create(record)
        except Exception as e:
            # Keep the prediction visible even though the sync failed.
            return f"Error: Failed to update Salesforce - {str(e)}", "Failure", "", prediction_details

        return (
            f"Success: Milestone: {milestone}",
            "Success",
            milestone,
            prediction_details
        )

    except Exception as e:
        # UI boundary: surface any unexpected failure as a status message.
        return f"Error: {str(e)}", "Failure", "", ""

# Gradio interface (for manual testing)
_PAGE_CSS = (
    ".gradio-container {background-color: #f0f4f8; font-family: Arial;} "
    ".title {color: #2c3e50; font-size: 24px; text-align: center;}"
)


def update_progress():
    """Return the slider value shown when processing starts."""
    return 50


def complete_progress():
    """Return the slider value shown once processing has finished."""
    return 100


with gr.Blocks(css=_PAGE_CSS) as demo:
    # Components are created in the order they appear on the page.
    gr.Markdown("<h1 class='title'>Construction Milestone Detector</h1>")
    project_name = gr.Textbox(label="Project Name", placeholder="Enter project name")
    image_input = gr.Image(type="filepath", label="Upload Construction Site Photo (JPG/PNG, ≤ 20MB)")
    submit_button = gr.Button("Process Image")
    output_text = gr.Textbox(label="Result")
    upload_status = gr.Textbox(label="Upload Status")
    milestone = gr.Textbox(label="Detected Milestone")
    prediction_details = gr.Textbox(label="Top Predictions")
    progress = gr.Slider(0, 100, label="Processing Progress", interactive=False, value=0)

    # Three chained steps per click: set the slider to 50, run the
    # classification and Salesforce sync, then set the slider to 100.
    _on_click = submit_button.click(fn=update_progress, outputs=progress)
    _on_processed = _on_click.then(
        fn=process_image,
        inputs=[image_input, project_name],
        outputs=[output_text, upload_status, milestone, prediction_details],
    )
    _on_processed.then(fn=complete_progress, outputs=progress)

# Launch the Gradio UI.
# NOTE(review): nothing in the visible code mounts the FastAPI `app` onto the
# Gradio server, so the /predict-milestone route does not appear to be served
# by this call -- confirm how the API is meant to be exposed.
demo.launch()