"""Construction milestone detector.

Runs a pre-trained ViT image classifier over an uploaded site photo and
exposes the result two ways:

* a Gradio UI (``demo``) that also upserts the prediction into the
  ``Construction_Project__c`` Salesforce object, and
* a FastAPI route (``POST /predict-milestone``) protected by an
  ``X-API-Key`` header.
"""

import base64
import io
import os
import re
import secrets
from datetime import datetime

import gradio as gr
import torch
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException, Security
from fastapi.security import APIKeyHeader
from PIL import Image
from simple_salesforce import Salesforce
from transformers import ViTForImageClassification, ViTImageProcessor

# Load environment variables
load_dotenv()
HF_API_KEY = os.getenv("HF_API_KEY")
SF_USERNAME = os.getenv("SF_USERNAME")
SF_PASSWORD = os.getenv("SF_PASSWORD")
SF_SECURITY_TOKEN = os.getenv("SF_SECURITY_TOKEN")
SF_CONSUMER_KEY = os.getenv("SF_CONSUMER_KEY")
SF_CONSUMER_SECRET = os.getenv("SF_CONSUMER_SECRET")
API_KEY = os.getenv("API_KEY", "your-api-key-here")  # Set in Space Secrets

# The placeholder default is a publicly guessable key; keep running (so the
# Gradio UI still works) but make the misconfiguration visible.
if API_KEY == "your-api-key-here":
    print("Warning: API_KEY is not set; the API endpoint is using the placeholder key.")

# Upload limits shared by the API endpoint and the Gradio handler.
MAX_IMAGE_MB = 20  # maximum accepted upload size, in MiB
MAX_SIDE = 1024  # images are downscaled to fit this box before inference
ALLOWED_FORMATS = ("jpeg", "png")  # Pillow format names, lower-cased
ALLOWED_EXTENSIONS = (".jpg", ".jpeg", ".png")

# Validate Salesforce credentials (for Gradio UI)
if not all([SF_USERNAME, SF_PASSWORD, SF_SECURITY_TOKEN, SF_CONSUMER_KEY, SF_CONSUMER_SECRET]):
    raise ValueError(
        "Missing Salesforce credentials. Set SF_USERNAME, SF_PASSWORD, SF_SECURITY_TOKEN, "
        "SF_CONSUMER_KEY, and SF_CONSUMER_SECRET in environment variables."
    )

# Initialize Salesforce connection
try:
    sf = Salesforce(
        username=SF_USERNAME,
        password=SF_PASSWORD,
        security_token=SF_SECURITY_TOKEN,
        consumer_key=SF_CONSUMER_KEY,
        consumer_secret=SF_CONSUMER_SECRET,
        domain='login',  # Use 'test' for sandbox
    )
except Exception as e:
    print(f"Salesforce connection failed: {str(e)}")
    raise

# Load pre-trained model and processor
# NOTE(review): this looks like the stock ImageNet checkpoint, so the labels
# it emits are presumably generic object classes rather than construction
# milestones - confirm whether a fine-tuned model is intended.
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
model.eval()

# FastAPI app for API endpoint
app = FastAPI()

# API Key authentication
api_key_header = APIKeyHeader(name="X-API-Key")


def _top_predictions(img, k):
    """Classify *img* and return its *k* most probable labels.

    The image is downscaled in place to fit ``MAX_SIDE`` x ``MAX_SIDE``
    before being handed to the processor.

    Returns:
        A list of ``(label, probability)`` tuples, most probable first.
        Indices missing from ``model.config.id2label`` map to
        ``"Unknown Milestone"``.
    """
    img.thumbnail((MAX_SIDE, MAX_SIDE), Image.Resampling.LANCZOS)
    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    probabilities = torch.softmax(logits, dim=1)
    top_probs, top_indices = torch.topk(probabilities, k, dim=1)
    return [
        (model.config.id2label.get(idx, "Unknown Milestone"), prob)
        for idx, prob in zip(top_indices[0].tolist(), top_probs[0].tolist())
    ]


async def verify_api_key(api_key: str = Security(api_key_header)):
    """FastAPI dependency: accept the request only if ``X-API-Key`` matches.

    Raises:
        HTTPException: 401 when the supplied key differs from ``API_KEY``.
    """
    # Constant-time comparison so response timing does not leak the key.
    # Compare bytes: compare_digest rejects non-ASCII str arguments.
    if not secrets.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid API Key")
    return api_key


@app.post("/predict-milestone")
async def predict_milestone(payload: dict, api_key: str = Depends(verify_api_key)):
    """Classify a base64-encoded image supplied in ``payload["image"]``.

    The value may be raw base64 or a ``data:image/...;base64,...`` URI.

    Returns:
        ``{"milestone": <top label>, "percent_complete": 75}``; the
        percentage is a fixed placeholder.

    Raises:
        HTTPException: 400 for a missing/invalid/oversized/unsupported
            image, 500 for any unexpected failure during processing.
    """
    try:
        # Validate payload
        if "image" not in payload:
            raise HTTPException(status_code=400, detail="Image field is required")
        image_data = payload["image"]
        if not isinstance(image_data, str):
            raise HTTPException(status_code=400, detail="Image field must be a base64 string")

        # Decode base64 image
        if image_data.startswith("data:image"):
            # Remove data URI prefix; partition() tolerates a missing comma.
            image_data = image_data.partition(",")[2]
        try:
            img_bytes = base64.b64decode(image_data)
        except ValueError as exc:  # binascii.Error is a ValueError subclass
            raise HTTPException(status_code=400, detail="Image is not valid base64") from exc

        # Validate image size (max 20MB per TSD) before decoding any pixels
        if len(img_bytes) / (1024 * 1024) > MAX_IMAGE_MB:
            raise HTTPException(status_code=400, detail="Image size exceeds 20MB")

        try:
            with Image.open(io.BytesIO(img_bytes)) as src:
                # Validate image type. The format must be read from the
                # opened file: the copy returned by convert() has no format.
                if (src.format or "").lower() not in ALLOWED_FORMATS:
                    raise HTTPException(status_code=400, detail="Only JPG/PNG images are supported")
                img = src.convert("RGB")
        except OSError as exc:  # includes PIL.UnidentifiedImageError
            raise HTTPException(status_code=400, detail="Image data could not be read") from exc

        # Preprocess and run inference (within 5 seconds per TSD)
        milestone, _ = _top_predictions(img, 1)[0]

        # Mock percent_complete (since model doesn't provide it)
        percent_complete = 75  # Adjust based on milestone or train a model to predict this

        return {
            "milestone": milestone,
            "percent_complete": percent_complete,
        }
    except HTTPException:
        # Deliberate 4xx responses must not be rewrapped as 500s below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e


def process_image(image, project_name):
    """Gradio handler: classify an uploaded photo and record it in Salesforce.

    Args:
        image: Filesystem path of the uploaded image, or ``None``.
        project_name: Project name typed by the user; used as the
            ``Name`` of the ``Construction_Project__c`` record.

    Returns:
        A 4-tuple ``(result message, upload status, milestone, top
        predictions)`` matching the four output textboxes. Failures are
        reported through the tuple rather than raised.
    """
    try:
        # Validate inputs
        if image is None:
            return "Error: Please upload an image to proceed.", "Pending", "", ""
        if not project_name:
            return "Error: Please enter a project name to proceed.", "Pending", "", ""
        if not re.match(r'^[a-zA-Z0-9\s-]+$', project_name):
            return (
                "Error: Project name must be alphanumeric (letters, numbers, spaces, or hyphens).",
                "Pending",
                "",
                "",
            )

        # Validate image size and type
        image_size_mb = os.path.getsize(image) / (1024 * 1024)
        if image_size_mb > MAX_IMAGE_MB:
            return "Error: Image size exceeds 20MB.", "Failure", "", ""
        if not image.lower().endswith(ALLOWED_EXTENSIONS):
            return "Error: Only JPG/PNG images are supported.", "Failure", "", ""

        # Preprocess image; the context manager releases the file handle.
        with Image.open(image) as src:
            img = src.convert("RGB")

        # Run inference and get top predictions
        predictions = _top_predictions(img, 3)
        milestone = predictions[0][0]

        # Format top predictions
        prediction_details = "\n".join(f"{label}: {prob:.2f}" for label, prob in predictions)

        # Update Salesforce record
        # NOTE(review): datetime.now() is naive local time - confirm the
        # timezone Salesforce should assume for Last_Updated_On__c.
        record = {
            "Name": project_name,
            "Current_Milestone__c": milestone,
            "Last_Updated_On__c": datetime.now().isoformat(),
            "Upload_Status__c": "Success",
            "Comments__c": f"AI Prediction: {milestone}",
            "Version__c": 1,
        }
        try:
            # project_name passed the allow-list regex above, so it cannot
            # contain quotes or backslashes and is safe to embed in SOQL.
            query = f"SELECT Id, Version__c FROM Construction_Project__c WHERE Name = '{project_name}'"
            result = sf.query(query)
            if result["totalSize"] > 0:
                existing = result["records"][0]
                # The field may come back as null; treat that as version 0.
                current_version = existing.get("Version__c") or 0
                record["Version__c"] = current_version + 1
                sf.Construction_Project__c.update(existing["Id"], record)
            else:
                sf.Construction_Project__c.create(record)
        except Exception as e:
            return f"Error: Failed to update Salesforce - {str(e)}", "Failure", "", prediction_details

        return (
            f"Success: Milestone: {milestone}",
            "Success",
            milestone,
            prediction_details,
        )
    except Exception as e:
        return f"Error: {str(e)}", "Failure", "", ""


# Gradio interface (for manual testing)
with gr.Blocks(css=".gradio-container {background-color: #f0f4f8; font-family: Arial;} .title {color: #2c3e50; font-size: 24px; text-align: center;}") as demo:
    gr.Markdown("<h1 class='title'>Construction Milestone Detector</h1>")
    project_name = gr.Textbox(label="Project Name", placeholder="Enter project name")
    image_input = gr.Image(type="filepath", label="Upload Construction Site Photo (JPG/PNG, ≤ 20MB)")
    submit_button = gr.Button("Process Image")
    output_text = gr.Textbox(label="Result")
    upload_status = gr.Textbox(label="Upload Status")
    milestone = gr.Textbox(label="Detected Milestone")
    prediction_details = gr.Textbox(label="Top Predictions")
    progress = gr.Slider(0, 100, label="Processing Progress", interactive=False, value=0)

    def update_progress():
        """Move the progress slider to the halfway mark."""
        return 50

    def complete_progress():
        """Move the progress slider to 100."""
        return 100

    # Chain: show 50% -> run the classifier/Salesforce update -> show 100%.
    submit_button.click(
        fn=update_progress,
        outputs=progress,
    ).then(
        fn=process_image,
        inputs=[image_input, project_name],
        outputs=[output_text, upload_status, milestone, prediction_details],
    ).then(
        fn=complete_progress,
        outputs=progress,
    )

# Start the Gradio server.
# NOTE(review): the FastAPI `app` above is never mounted or served from this
# script - only the Gradio UI is launched here, so /predict-milestone is
# reachable only if `app` is served separately or combined with `demo`
# (e.g. via gr.mount_gradio_app) under an ASGI server.
demo.launch()