Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,6 +7,10 @@ import re
|
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
from simple_salesforce import Salesforce
|
| 9 |
from datetime import datetime
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
# Load environment variables
|
| 12 |
load_dotenv()
|
|
@@ -16,10 +20,11 @@ SF_PASSWORD = os.getenv("SF_PASSWORD")
|
|
| 16 |
SF_SECURITY_TOKEN = os.getenv("SF_SECURITY_TOKEN")
|
| 17 |
SF_CONSUMER_KEY = os.getenv("SF_CONSUMER_KEY")
|
| 18 |
SF_CONSUMER_SECRET = os.getenv("SF_CONSUMER_SECRET")
|
|
|
|
| 19 |
|
| 20 |
-
# Validate Salesforce credentials
|
| 21 |
if not all([SF_USERNAME, SF_PASSWORD, SF_SECURITY_TOKEN, SF_CONSUMER_KEY, SF_CONSUMER_SECRET]):
|
| 22 |
-
raise ValueError("Missing Salesforce credentials.
|
| 23 |
|
| 24 |
# Initialize Salesforce connection
|
| 25 |
try:
|
|
@@ -36,15 +41,71 @@ except Exception as e:
|
|
| 36 |
raise
|
| 37 |
|
| 38 |
# Load pre-trained model and processor
|
| 39 |
-
model_name = "google/vit-base-patch16-224"
|
| 40 |
processor = ViTImageProcessor.from_pretrained(model_name)
|
| 41 |
model = ViTForImageClassification.from_pretrained(model_name)
|
| 42 |
model.eval()
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
def process_image(image, project_name):
|
| 45 |
-
"""
|
| 46 |
-
Process uploaded image, predict construction milestone, and update Salesforce.
|
| 47 |
-
"""
|
| 48 |
try:
|
| 49 |
# Validate inputs
|
| 50 |
if image is None:
|
|
@@ -63,7 +124,7 @@ def process_image(image, project_name):
|
|
| 63 |
|
| 64 |
# Preprocess image
|
| 65 |
img = Image.open(image).convert("RGB")
|
| 66 |
-
max_size = 1024
|
| 67 |
img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
|
| 68 |
inputs = processor(images=img, return_tensors="pt")
|
| 69 |
|
|
@@ -96,7 +157,7 @@ def process_image(image, project_name):
|
|
| 96 |
}
|
| 97 |
|
| 98 |
try:
|
| 99 |
-
project_name = project_name.replace("'", "''")
|
| 100 |
query = f"SELECT Id, Version__c FROM Construction_Project__c WHERE Name = '{project_name}'"
|
| 101 |
result = sf.query(query)
|
| 102 |
if result["totalSize"] > 0:
|
|
@@ -119,7 +180,7 @@ def process_image(image, project_name):
|
|
| 119 |
except Exception as e:
|
| 120 |
return f"Error: {str(e)}", "Failure", "", ""
|
| 121 |
|
| 122 |
-
# Gradio interface
|
| 123 |
with gr.Blocks(css=".gradio-container {background-color: #f0f4f8; font-family: Arial;} .title {color: #2c3e50; font-size: 24px; text-align: center;}") as demo:
|
| 124 |
gr.Markdown("<h1 class='title'>Construction Milestone Detector</h1>")
|
| 125 |
project_name = gr.Textbox(label="Project Name", placeholder="Enter project name")
|
|
@@ -149,5 +210,5 @@ with gr.Blocks(css=".gradio-container {background-color: #f0f4f8; font-family: A
|
|
| 149 |
outputs=progress
|
| 150 |
)
|
| 151 |
|
| 152 |
-
|
| 153 |
-
|
|
|
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
from simple_salesforce import Salesforce
|
| 9 |
from datetime import datetime
|
| 10 |
+
from fastapi import FastAPI, HTTPException, Security, Depends
|
| 11 |
+
from fastapi.security import APIKeyHeader
|
| 12 |
+
import base64
|
| 13 |
+
import io
|
| 14 |
|
| 15 |
# Load environment variables
|
| 16 |
load_dotenv()
|
|
|
|
| 20 |
SF_SECURITY_TOKEN = os.getenv("SF_SECURITY_TOKEN")
|
| 21 |
SF_CONSUMER_KEY = os.getenv("SF_CONSUMER_KEY")
|
| 22 |
SF_CONSUMER_SECRET = os.getenv("SF_CONSUMER_SECRET")
|
| 23 |
+
API_KEY = os.getenv("API_KEY", "your-api-key-here") # Set in Space Secrets
|
| 24 |
|
| 25 |
+
# Validate Salesforce credentials (for Gradio UI)
|
| 26 |
if not all([SF_USERNAME, SF_PASSWORD, SF_SECURITY_TOKEN, SF_CONSUMER_KEY, SF_CONSUMER_SECRET]):
|
| 27 |
+
raise ValueError("Missing Salesforce credentials. Set SF_USERNAME, SF_PASSWORD, SF_SECURITY_TOKEN, SF_CONSUMER_KEY, and SF_CONSUMER_SECRET in environment variables.")
|
| 28 |
|
| 29 |
# Initialize Salesforce connection
|
| 30 |
try:
|
|
|
|
| 41 |
raise
|
| 42 |
|
| 43 |
# Load pre-trained model and processor
|
| 44 |
+
model_name = "google/vit-base-patch16-224"
|
| 45 |
processor = ViTImageProcessor.from_pretrained(model_name)
|
| 46 |
model = ViTForImageClassification.from_pretrained(model_name)
|
| 47 |
model.eval()
|
| 48 |
|
| 49 |
+
# FastAPI app for API endpoint
|
| 50 |
+
app = FastAPI()
|
| 51 |
+
|
| 52 |
+
# API Key authentication
|
| 53 |
+
api_key_header = APIKeyHeader(name="X-API-Key")
|
| 54 |
+
async def verify_api_key(api_key: str = Security(api_key_header)):
|
| 55 |
+
if api_key != API_KEY:
|
| 56 |
+
raise HTTPException(status_code=401, detail="Invalid API Key")
|
| 57 |
+
return api_key
|
| 58 |
+
|
| 59 |
+
@app.post("/predict-milestone")
|
| 60 |
+
async def predict_milestone(payload: dict, api_key: str = Depends(verify_api_key)):
|
| 61 |
+
try:
|
| 62 |
+
# Validate payload
|
| 63 |
+
if "image" not in payload:
|
| 64 |
+
raise HTTPException(status_code=400, detail="Image field is required")
|
| 65 |
+
|
| 66 |
+
# Decode base64 image
|
| 67 |
+
image_data = payload["image"]
|
| 68 |
+
if image_data.startswith("data:image"):
|
| 69 |
+
image_data = image_data.split(",")[1] # Remove data URI prefix
|
| 70 |
+
img_bytes = base64.b64decode(image_data)
|
| 71 |
+
img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
|
| 72 |
+
|
| 73 |
+
# Validate image size (max 20MB per TSD)
|
| 74 |
+
img_bytes_size = len(img_bytes) / (1024 * 1024)
|
| 75 |
+
if img_bytes_size > 20:
|
| 76 |
+
raise HTTPException(status_code=400, detail="Image size exceeds 20MB")
|
| 77 |
+
|
| 78 |
+
# Validate image type
|
| 79 |
+
if not img.format.lower() in ["jpeg", "png"]:
|
| 80 |
+
raise HTTPException(status_code=400, detail="Only JPG/PNG images are supported")
|
| 81 |
+
|
| 82 |
+
# Preprocess image
|
| 83 |
+
max_size = 1024 # Optimize for performance
|
| 84 |
+
img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
|
| 85 |
+
inputs = processor(images=img, return_tensors="pt")
|
| 86 |
+
|
| 87 |
+
# Run inference (within 5 seconds per TSD)
|
| 88 |
+
with torch.no_grad():
|
| 89 |
+
outputs = model(**inputs)
|
| 90 |
+
logits = outputs.logits
|
| 91 |
+
probabilities = torch.softmax(logits, dim=1)
|
| 92 |
+
|
| 93 |
+
# Get top prediction
|
| 94 |
+
predicted_idx = torch.argmax(probabilities, dim=1).item()
|
| 95 |
+
milestone = model.config.id2label.get(predicted_idx, "Unknown Milestone")
|
| 96 |
+
|
| 97 |
+
# Mock percent_complete (since model doesn't provide it)
|
| 98 |
+
percent_complete = 75 # Adjust based on milestone or train a model to predict this
|
| 99 |
+
|
| 100 |
+
return {
|
| 101 |
+
"milestone": milestone,
|
| 102 |
+
"percent_complete": percent_complete
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
except Exception as e:
|
| 106 |
+
raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
|
| 107 |
+
|
| 108 |
def process_image(image, project_name):
|
|
|
|
|
|
|
|
|
|
| 109 |
try:
|
| 110 |
# Validate inputs
|
| 111 |
if image is None:
|
|
|
|
| 124 |
|
| 125 |
# Preprocess image
|
| 126 |
img = Image.open(image).convert("RGB")
|
| 127 |
+
max_size = 1024
|
| 128 |
img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
|
| 129 |
inputs = processor(images=img, return_tensors="pt")
|
| 130 |
|
|
|
|
| 157 |
}
|
| 158 |
|
| 159 |
try:
|
| 160 |
+
project_name = project_name.replace("'", "''")
|
| 161 |
query = f"SELECT Id, Version__c FROM Construction_Project__c WHERE Name = '{project_name}'"
|
| 162 |
result = sf.query(query)
|
| 163 |
if result["totalSize"] > 0:
|
|
|
|
| 180 |
except Exception as e:
|
| 181 |
return f"Error: {str(e)}", "Failure", "", ""
|
| 182 |
|
| 183 |
+
# Gradio interface (for manual testing)
|
| 184 |
with gr.Blocks(css=".gradio-container {background-color: #f0f4f8; font-family: Arial;} .title {color: #2c3e50; font-size: 24px; text-align: center;}") as demo:
|
| 185 |
gr.Markdown("<h1 class='title'>Construction Milestone Detector</h1>")
|
| 186 |
project_name = gr.Textbox(label="Project Name", placeholder="Enter project name")
|
|
|
|
| 210 |
outputs=progress
|
| 211 |
)
|
| 212 |
|
| 213 |
+
# Mount FastAPI app to Gradio
|
| 214 |
+
demo.launch()
|