Files changed (1) hide show
  1. app.py +194 -194
app.py CHANGED
@@ -1,214 +1,214 @@
1
  import gradio as gr
2
- from PIL import Image
3
- import torch
4
- from transformers import ViTForImageClassification, ViTImageProcessor
5
  import os
6
- import re
7
- from dotenv import load_dotenv
 
8
  from simple_salesforce import Salesforce
9
- from datetime import datetime
10
- from fastapi import FastAPI, HTTPException, Security, Depends
11
- from fastapi.security import APIKeyHeader
12
- import base64
13
  import io
 
 
14
 
15
- # Load environment variables
16
  load_dotenv()
17
- HF_API_KEY = os.getenv("HF_API_KEY")
18
- SF_USERNAME = os.getenv("SF_USERNAME")
19
- SF_PASSWORD = os.getenv("SF_PASSWORD")
20
- SF_SECURITY_TOKEN = os.getenv("SF_SECURITY_TOKEN")
21
- SF_CONSUMER_KEY = os.getenv("SF_CONSUMER_KEY")
22
- SF_CONSUMER_SECRET = os.getenv("SF_CONSUMER_SECRET")
23
- API_KEY = os.getenv("API_KEY", "your-api-key-here") # Set in Space Secrets
24
-
25
- # Validate Salesforce credentials (for Gradio UI)
26
- if not all([SF_USERNAME, SF_PASSWORD, SF_SECURITY_TOKEN, SF_CONSUMER_KEY, SF_CONSUMER_SECRET]):
27
- raise ValueError("Missing Salesforce credentials. Set SF_USERNAME, SF_PASSWORD, SF_SECURITY_TOKEN, SF_CONSUMER_KEY, and SF_CONSUMER_SECRET in environment variables.")
28
-
29
- # Initialize Salesforce connection
30
- try:
31
- sf = Salesforce(
32
- username=SF_USERNAME,
33
- password=SF_PASSWORD,
34
- security_token=SF_SECURITY_TOKEN,
35
- consumer_key=SF_CONSUMER_KEY,
36
- consumer_secret=SF_CONSUMER_SECRET,
37
- domain='login' # Use 'test' for sandbox
38
- )
39
- except Exception as e:
40
- print(f"Salesforce connection failed: {str(e)}")
41
- raise
42
-
43
- # Load pre-trained model and processor
44
- model_name = "google/vit-base-patch16-224"
45
- processor = ViTImageProcessor.from_pretrained(model_name)
46
- model = ViTForImageClassification.from_pretrained(model_name)
47
- model.eval()
48
-
49
- # FastAPI app for API endpoint
50
- app = FastAPI()
51
 
52
- # API Key authentication
53
- api_key_header = APIKeyHeader(name="X-API-Key")
54
- async def verify_api_key(api_key: str = Security(api_key_header)):
55
- if api_key != API_KEY:
56
- raise HTTPException(status_code=401, detail="Invalid API Key")
57
- return api_key
58
-
59
- @app.post("/predict-milestone")
60
- async def predict_milestone(payload: dict, api_key: str = Depends(verify_api_key)):
 
 
 
 
61
  try:
62
- # Validate payload
63
- if "image" not in payload:
64
- raise HTTPException(status_code=400, detail="Image field is required")
65
 
66
- # Decode base64 image
67
- image_data = payload["image"]
68
- if image_data.startswith("data:image"):
69
- image_data = image_data.split(",")[1] # Remove data URI prefix
70
- img_bytes = base64.b64decode(image_data)
71
- img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
72
-
73
- # Validate image size (max 20MB per TSD)
74
- img_bytes_size = len(img_bytes) / (1024 * 1024)
75
- if img_bytes_size > 20:
76
- raise HTTPException(status_code=400, detail="Image size exceeds 20MB")
77
-
78
- # Validate image type
79
- if not img.format.lower() in ["jpeg", "png"]:
80
- raise HTTPException(status_code=400, detail="Only JPG/PNG images are supported")
81
-
82
- # Preprocess image
83
- max_size = 1024 # Optimize for performance
84
- img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
85
- inputs = processor(images=img, return_tensors="pt")
86
-
87
- # Run inference (within 5 seconds per TSD)
88
- with torch.no_grad():
89
- outputs = model(**inputs)
90
- logits = outputs.logits
91
- probabilities = torch.softmax(logits, dim=1)
92
-
93
- # Get top prediction
94
- predicted_idx = torch.argmax(probabilities, dim=1).item()
95
- milestone = model.config.id2label.get(predicted_idx, "Unknown Milestone")
96
-
97
- # Mock percent_complete (since model doesn't provide it)
98
- percent_complete = 75 # Adjust based on milestone or train a model to predict this
99
-
100
- return {
101
- "milestone": milestone,
102
- "percent_complete": percent_complete
103
  }
104
-
 
 
 
 
 
 
 
 
 
 
 
 
105
  except Exception as e:
106
- raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
107
 
108
- def process_image(image, project_name):
 
109
  try:
110
- # Validate inputs
111
- if image is None:
112
- return "Error: Please upload an image to proceed.", "Pending", "", ""
113
- if not project_name:
114
- return "Error: Please enter a project name to proceed.", "Pending", "", ""
115
- if not re.match(r'^[a-zA-Z0-9\s-]+$', project_name):
116
- return "Error: Project name must be alphanumeric (letters, numbers, spaces, or hyphens).", "Pending", "", ""
117
-
118
- # Validate image size and type
119
- image_size_mb = os.path.getsize(image) / (1024 * 1024)
120
- if image_size_mb > 20:
121
- return "Error: Image size exceeds 20MB.", "Failure", "", ""
122
- if not image.lower().endswith(('.jpg', '.jpeg', '.png')):
123
- return "Error: Only JPG/PNG images are supported.", "Failure", "", ""
124
-
125
- # Preprocess image
126
- img = Image.open(image).convert("RGB")
127
- max_size = 1024
128
- img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
129
- inputs = processor(images=img, return_tensors="pt")
130
-
131
- # Run inference
132
- with torch.no_grad():
133
- outputs = model(**inputs)
134
- logits = outputs.logits
135
- probabilities = torch.softmax(logits, dim=1)
136
 
137
- # Get top predictions
138
- top_probs, top_indices = torch.topk(probabilities, 3, dim=1)
139
- top_probs = top_probs[0].tolist()
140
- top_indices = top_indices[0].tolist()
 
 
141
 
142
- # Map indices to labels
143
- predicted_idx = top_indices[0]
144
- milestone = model.config.id2label.get(predicted_idx, "Unknown Milestone")
145
 
146
- # Format top predictions
147
- prediction_details = "\n".join([f"{model.config.id2label.get(idx, 'Unknown Milestone')}: {prob:.2f}" for idx, prob in zip(top_indices, top_probs)])
148
-
149
- # Update Salesforce record
150
- record = {
151
- "Name": project_name,
152
- "Current_Milestone__c": milestone,
153
- "Last_Updated_On__c": datetime.now().isoformat(),
154
- "Upload_Status__c": "Success",
155
- "Comments__c": f"AI Prediction: {milestone}",
156
- "Version__c": 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  }
158
-
159
- try:
160
- project_name = project_name.replace("'", "''")
161
- query = f"SELECT Id, Version__c FROM Construction_Project__c WHERE Name = '{project_name}'"
162
- result = sf.query(query)
163
- if result["totalSize"] > 0:
164
- project_id = result["records"][0]["Id"]
165
- current_version = result["records"][0].get("Version__c", 0)
166
- record["Version__c"] = current_version + 1
167
- sf.Construction_Project__c.update(project_id, record)
168
- else:
169
- sf.Construction_Project__c.create(record)
170
- except Exception as e:
171
- return f"Error: Failed to update Salesforce - {str(e)}", "Failure", "", prediction_details
172
-
173
- return (
174
- f"Success: Milestone: {milestone}",
175
- "Success",
176
- milestone,
177
- prediction_details
178
  )
179
-
180
  except Exception as e:
181
- return f"Error: {str(e)}", "Failure", "", ""
182
-
183
- # Gradio interface (for manual testing)
184
- with gr.Blocks(css=".gradio-container {background-color: #f0f4f8; font-family: Arial;} .title {color: #2c3e50; font-size: 24px; text-align: center;}") as demo:
185
- gr.Markdown("<h1 class='title'>Construction Milestone Detector</h1>")
186
- project_name = gr.Textbox(label="Project Name", placeholder="Enter project name")
187
- image_input = gr.Image(type="filepath", label="Upload Construction Site Photo (JPG/PNG, ≤ 20MB)")
188
- submit_button = gr.Button("Process Image")
189
- output_text = gr.Textbox(label="Result")
190
- upload_status = gr.Textbox(label="Upload Status")
191
- milestone = gr.Textbox(label="Detected Milestone")
192
- prediction_details = gr.Textbox(label="Top Predictions")
193
- progress = gr.Slider(0, 100, label="Processing Progress", interactive=False, value=0)
194
-
195
- def update_progress():
196
- return 50
197
-
198
- def complete_progress():
199
- return 100
200
-
201
- submit_button.click(
202
- fn=update_progress,
203
- outputs=progress
204
- ).then(
205
- fn=process_image,
206
- inputs=[image_input, project_name],
207
- outputs=[output_text, upload_status, milestone, prediction_details]
208
- ).then(
209
- fn=complete_progress,
210
- outputs=progress
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  )
212
-
213
- # Mount FastAPI app to Gradio
214
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import requests
 
 
3
  import os
4
+ from PIL import Image
5
+ import numpy as np
6
+ from transformers import pipeline
7
  from simple_salesforce import Salesforce
 
 
 
 
8
  import io
9
+ import time
10
+ from dotenv import load_dotenv
11
 
12
+ # Load environment variables from .env file
13
  load_dotenv()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ # Function to validate photo size (< 20MB)
16
+ def validate_photo_size(image_file):
17
+ max_size_mb = 20
18
+ if isinstance(image_file, Image.Image):
19
+ # Convert PIL Image to bytes for size check
20
+ img_byte_arr = io.BytesIO()
21
+ image_file.save(img_byte_arr, format='JPEG')
22
+ file_size_mb = img_byte_arr.tell() / (1024 * 1024) # Convert bytes to MB
23
+ return file_size_mb <= max_size_mb, None
24
+ return False, "Invalid image format"
25
+
26
+ # Function to process image with AI and predict milestone
27
+ def predict_milestone(image):
28
  try:
29
+ # Simulate AI processing time (ensure < 5 seconds)
30
+ start_time = time.time()
 
31
 
32
+ # Process image with Hugging Face model
33
+ model = pipeline("image-classification", model="microsoft/resnet-50")
34
+ predictions = model(image)
35
+
36
+ # Placeholder logic: Map model output to construction milestones
37
+ milestone = predictions[0]["label"] # Example: "positive" -> "Walls Erected"
38
+ confidence = predictions[0]["score"]
39
+
40
+ # Map model output to construction milestones (customize this)
41
+ milestone_map = {
42
+ "positive": "Walls Erected",
43
+ "negative": "Foundation Completed",
44
+ # Add more mappings based on your model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  }
46
+ completion_map = {
47
+ "positive": 60.00, # Example: Walls = 60% complete
48
+ "negative": 20.00, # Example: Foundation = 20% complete
49
+ }
50
+
51
+ predicted_milestone = milestone_map.get(milestone, "Unknown Milestone")
52
+ completion_percentage = completion_map.get(milestone, 0.00)
53
+
54
+ processing_time = time.time() - start_time
55
+ if processing_time > 5:
56
+ return None, None, "AI took too long to process (> 5 seconds)."
57
+
58
+ return predicted_milestone, completion_percentage, None
59
  except Exception as e:
60
+ return None, None, f"AI failed to process the image: {str(e)}"
61
 
62
+ # Function to upload image to Salesforce and get a URL
63
+ def upload_image_to_salesforce(image, project_name):
64
  try:
65
+ # Placeholder: Simulate uploading image to Salesforce ContentVersion
66
+ image_url = f"https://your-salesforce-instance.com/file/{project_name}.jpg" # Simulated URL
67
+ return image_url, None
68
+ except Exception as e:
69
+ return None, f"Failed to upload image to Salesforce: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
+ # Function to update Salesforce Construction_Project__c object and fetch fields
72
+ def update_salesforce_record(sf, project_name, milestone, percentage, image_url, status, comments):
73
+ try:
74
+ # Query to check if the project exists
75
+ query = f"SELECT Id FROM Construction_Project__c WHERE Name = '{project_name}'"
76
+ result = sf.query(query)
77
 
78
+ if result['totalSize'] == 0:
79
+ return None, f"No project found with Name: {project_name}"
 
80
 
81
+ record_id = result['records'][0]['Id']
82
+
83
+ # Update the record
84
+ sf.Construction_Project__c.update(record_id, {
85
+ 'Current_Milestone__c': milestone,
86
+ 'Completion_Percentage__c': percentage,
87
+ 'Last_Updated_Image__c': image_url,
88
+ 'Last_Updated_On__c': time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime()),
89
+ 'Upload_Status__c': status,
90
+ 'Comments__c': comments
91
+ })
92
+
93
+ # Fetch the updated record to get the specified fields
94
+ updated_query = f"SELECT Current_Milestone__c, Last_Updated_Image__c, Last_Updated_On__c, Upload_Status__c FROM Construction_Project__c WHERE Id = '{record_id}'"
95
+ updated_result = sf.query(updated_query)
96
+
97
+ if updated_result['totalSize'] == 0:
98
+ return None, "Failed to retrieve updated record."
99
+
100
+ record = updated_result['records'][0]
101
+ fields_output = {
102
+ 'Current_Milestone__c': record.get('Current_Milestone__c', 'N/A'),
103
+ 'Last_Updated_Image__c': record.get('Last_Updated_Image__c', 'N/A'),
104
+ 'Last_Updated_On__c': record.get('Last_Updated_On__c', 'N/A'),
105
+ 'Upload_Status__c': record.get('Upload_Status__c', 'N/A')
106
  }
107
+ return fields_output, None
108
+ except Exception as e:
109
+ return None, f"Failed to update Salesforce: {str(e)}"
110
+
111
+ # Main Gradio function
112
+ def process_construction_photo(project_name, image):
113
+ if not project_name or not image:
114
+ return None, "Please provide a project name and upload a photo."
115
+
116
+ # Connect to Salesforce
117
+ try:
118
+ sf = Salesforce(
119
+ username=os.getenv('SALESFORCE_USERNAME'),
120
+ password=os.getenv('SALESFORCE_PASSWORD'),
121
+ security_token=os.getenv('SALESFORCE_SECURITY_TOKEN'),
122
+ domain=os.getenv('SALESFORCE_DOMAIN')
 
 
 
 
123
  )
 
124
  except Exception as e:
125
+ return None, f"Failed to connect to Salesforce: {str(e)}"
126
+
127
+ # Validate photo size
128
+ is_valid, error = validate_photo_size(image)
129
+ if not is_valid:
130
+ return None, error or "Photo is too large! Please upload a photo smaller than 20MB."
131
+
132
+ # Process the image with AI
133
+ milestone, percentage, error = predict_milestone(image)
134
+
135
+ if error:
136
+ fields, error_message = update_salesforce_record(
137
+ sf=sf,
138
+ project_name=project_name,
139
+ milestone=None,
140
+ percentage=0.00,
141
+ image_url=None,
142
+ status="Failure",
143
+ comments=error
144
+ )
145
+ error_text = f"AI Error: {error}"
146
+ if error_message:
147
+ error_text += f"\nSalesforce Error: {error_message}"
148
+ if fields:
149
+ error_text += "\nUpdated Salesforce Fields:\n"
150
+ for field, value in fields.items():
151
+ error_text += f"{field}: {value}\n"
152
+ return None, error_text
153
+
154
+ # Upload image to Salesforce
155
+ image_url, upload_error = upload_image_to_salesforce(image, project_name)
156
+
157
+ if upload_error:
158
+ fields, error_message = update_salesforce_record(
159
+ sf=sf,
160
+ project_name=project_name,
161
+ milestone=milestone,
162
+ percentage=percentage,
163
+ image_url=None,
164
+ status="Failure",
165
+ comments=upload_error
166
+ )
167
+ error_text = f"Upload Error: {upload_error}"
168
+ if error_message:
169
+ error_text += f"\nSalesforce Error: {error_message}"
170
+ if fields:
171
+ error_text += "\nUpdated Salesforce Fields:\n"
172
+ for field, value in fields.items():
173
+ error_text += f"{field}: {value}\n"
174
+ return None, error_text
175
+
176
+ # Update Salesforce with success
177
+ fields, error_message = update_salesforce_record(
178
+ sf=sf,
179
+ project_name=project_name,
180
+ milestone=milestone,
181
+ percentage=percentage,
182
+ image_url=image_url,
183
+ status="Success",
184
+ comments="Photo processed successfully"
185
  )
186
+
187
+ if error_message:
188
+ return None, f"Salesforce Error: {error_message}"
189
+
190
+ # Prepare output with AI results and Salesforce fields
191
+ result_text = f"Success! Milestone: {milestone}, Completion: {percentage}%\nProgress saved to Salesforce!\n\nSalesforce Fields:\n"
192
+ for field, value in fields.items():
193
+ result_text += f"{field}: {value}\n"
194
+
195
+ return image, result_text
196
+
197
+ # Gradio interface
198
+ iface = gr.Interface(
199
+ fn=process_construction_photo,
200
+ inputs=[
201
+ gr.Textbox(label="Project Name (e.g., Sunshine Apartments)", placeholder="Sunshine Apartments"),
202
+ gr.Image(type="pil", label="Upload a Construction Photo")
203
+ ],
204
+ outputs=[
205
+ gr.Image(label="Uploaded Photo"),
206
+ gr.Textbox(label="Result")
207
+ ],
208
+ title="Construction Project Progress Tracker",
209
+ description="Upload a photo of your construction site, and the AI will tell you the progress!"
210
+ )
211
+
212
+ # Launch the Gradio app
213
+ if __name__ == "__main__":
214
+ iface.launch()