# Hugging Face Space: Construction Milestone Detector
# Gradio UI for manual testing plus a FastAPI prediction endpoint.
| import gradio as gr | |
| from PIL import Image | |
| import torch | |
| from transformers import ViTForImageClassification, ViTImageProcessor | |
| import os | |
| import re | |
| from dotenv import load_dotenv | |
| from simple_salesforce import Salesforce | |
| from datetime import datetime | |
| from fastapi import FastAPI, HTTPException, Security, Depends | |
| from fastapi.security import APIKeyHeader | |
| import base64 | |
| import io | |
# ---------------------------------------------------------------------------
# Configuration: pull secrets from the environment (.env supported locally).
# ---------------------------------------------------------------------------
load_dotenv()

HF_API_KEY = os.getenv("HF_API_KEY")
SF_USERNAME = os.getenv("SF_USERNAME")
SF_PASSWORD = os.getenv("SF_PASSWORD")
SF_SECURITY_TOKEN = os.getenv("SF_SECURITY_TOKEN")
SF_CONSUMER_KEY = os.getenv("SF_CONSUMER_KEY")
SF_CONSUMER_SECRET = os.getenv("SF_CONSUMER_SECRET")
API_KEY = os.getenv("API_KEY", "your-api-key-here")  # Set in Space Secrets

# Fail fast if any Salesforce credential is absent (the Gradio UI needs them).
_sf_credentials = (
    SF_USERNAME,
    SF_PASSWORD,
    SF_SECURITY_TOKEN,
    SF_CONSUMER_KEY,
    SF_CONSUMER_SECRET,
)
if not all(_sf_credentials):
    raise ValueError("Missing Salesforce credentials. Set SF_USERNAME, SF_PASSWORD, SF_SECURITY_TOKEN, SF_CONSUMER_KEY, and SF_CONSUMER_SECRET in environment variables.")
# Initialize the Salesforce connection at import time; a failure is logged and
# re-raised so the Space aborts startup instead of serving a broken app.
try:
    sf = Salesforce(
        username=SF_USERNAME,
        password=SF_PASSWORD,
        security_token=SF_SECURITY_TOKEN,
        consumer_key=SF_CONSUMER_KEY,
        consumer_secret=SF_CONSUMER_SECRET,
        domain='login'  # Use 'test' for sandbox
    )
except Exception as e:
    # Surface the root cause in the Space logs before aborting.
    print(f"Salesforce connection failed: {str(e)}")
    raise
# Load the pre-trained ViT classifier and its paired image processor once at
# module import; both are reused by the API endpoint and the Gradio handler.
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
model.eval()  # inference-only: disables dropout / training-mode behavior
# FastAPI application exposing the prediction API.
app = FastAPI()

# Clients must present the shared secret in the "X-API-Key" request header.
api_key_header = APIKeyHeader(name="X-API-Key")

async def verify_api_key(api_key: str = Security(api_key_header)):
    """Dependency that rejects requests whose X-API-Key header does not match."""
    if api_key == API_KEY:
        return api_key
    raise HTTPException(status_code=401, detail="Invalid API Key")
@app.post("/predict")  # was never registered as a route — the endpoint was unreachable
async def predict_milestone(payload: dict, api_key: str = Depends(verify_api_key)):
    """Classify a base64-encoded construction-site photo.

    Expects ``payload["image"]`` to hold a base64 string, optionally prefixed
    with a ``data:image/...`` URI header.

    Returns:
        dict with ``milestone`` (predicted label) and ``percent_complete``
        (currently a mocked constant — the classifier does not predict it).

    Raises:
        HTTPException: 400 for a missing/oversized/wrong-type image,
            401 via the API-key dependency, 500 for unexpected errors.
    """
    try:
        # Validate payload shape.
        if "image" not in payload:
            raise HTTPException(status_code=400, detail="Image field is required")

        # Decode base64 image, stripping an optional data-URI prefix.
        image_data = payload["image"]
        if image_data.startswith("data:image"):
            image_data = image_data.split(",")[1]
        img_bytes = base64.b64decode(image_data)

        # Validate size (max 20MB per TSD) BEFORE decoding pixels.
        if len(img_bytes) / (1024 * 1024) > 20:
            raise HTTPException(status_code=400, detail="Image size exceeds 20MB")

        img = Image.open(io.BytesIO(img_bytes))
        # Validate type. NOTE: .format must be read before convert() — the
        # converted copy has format=None, which crashed the original code.
        if (img.format or "").lower() not in ("jpeg", "png"):
            raise HTTPException(status_code=400, detail="Only JPG/PNG images are supported")
        img = img.convert("RGB")

        # Preprocess: cap the longest side at 1024px for performance.
        img.thumbnail((1024, 1024), Image.Resampling.LANCZOS)
        inputs = processor(images=img, return_tensors="pt")

        # Run inference (target: within 5 seconds per TSD).
        with torch.no_grad():
            outputs = model(**inputs)
        probabilities = torch.softmax(outputs.logits, dim=1)

        # Top-1 prediction mapped to its label.
        predicted_idx = torch.argmax(probabilities, dim=1).item()
        milestone = model.config.id2label.get(predicted_idx, "Unknown Milestone")

        # Mock percent_complete (model doesn't provide it) — TODO: train for it.
        percent_complete = 75
        return {
            "milestone": milestone,
            "percent_complete": percent_complete
        }
    except HTTPException:
        # Preserve deliberate 4xx responses instead of masking them as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
def process_image(image, project_name):
    """Gradio handler: classify an uploaded photo and upsert the result to Salesforce.

    Args:
        image: Filepath of the uploaded photo (gr.Image with type="filepath").
        project_name: Project name; letters, digits, spaces, and hyphens only.

    Returns:
        Tuple of (result message, upload status, milestone label, top-3 details).
        Errors are reported through the message/status fields, never raised.
    """
    try:
        # --- input validation ------------------------------------------------
        if image is None:
            return "Error: Please upload an image to proceed.", "Pending", "", ""
        if not project_name:
            return "Error: Please enter a project name to proceed.", "Pending", "", ""
        if not re.match(r'^[a-zA-Z0-9\s-]+$', project_name):
            return "Error: Project name must be alphanumeric (letters, numbers, spaces, or hyphens).", "Pending", "", ""
        # Enforce the 20MB / JPG-PNG limits before decoding anything.
        image_size_mb = os.path.getsize(image) / (1024 * 1024)
        if image_size_mb > 20:
            return "Error: Image size exceeds 20MB.", "Failure", "", ""
        if not image.lower().endswith(('.jpg', '.jpeg', '.png')):
            return "Error: Only JPG/PNG images are supported.", "Failure", "", ""

        # --- inference -------------------------------------------------------
        img = Image.open(image).convert("RGB")
        img.thumbnail((1024, 1024), Image.Resampling.LANCZOS)  # cap size for speed
        inputs = processor(images=img, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        probabilities = torch.softmax(outputs.logits, dim=1)

        # Top-3 predictions; the best one becomes the milestone.
        top_probs, top_indices = torch.topk(probabilities, 3, dim=1)
        top_probs = top_probs[0].tolist()
        top_indices = top_indices[0].tolist()
        milestone = model.config.id2label.get(top_indices[0], "Unknown Milestone")
        prediction_details = "\n".join(
            f"{model.config.id2label.get(idx, 'Unknown Milestone')}: {prob:.2f}"
            for idx, prob in zip(top_indices, top_probs)
        )

        # --- Salesforce upsert -----------------------------------------------
        record = {
            "Name": project_name,
            "Current_Milestone__c": milestone,
            "Last_Updated_On__c": datetime.now().isoformat(),
            "Upload_Status__c": "Success",
            "Comments__c": f"AI Prediction: {milestone}",
            "Version__c": 1
        }
        try:
            # Escape single quotes for SOQL; keep the unescaped name in `record`
            # (the original mutated project_name in place, which only worked
            # because the record dict happened to be built first).
            safe_name = project_name.replace("'", "''")
            query = f"SELECT Id, Version__c FROM Construction_Project__c WHERE Name = '{safe_name}'"
            result = sf.query(query)
            if result["totalSize"] > 0:
                project_id = result["records"][0]["Id"]
                # Version__c can be null on existing records, and .get(...)
                # then returns None — the original `None + 1` was a TypeError.
                current_version = result["records"][0].get("Version__c") or 0
                record["Version__c"] = current_version + 1
                sf.Construction_Project__c.update(project_id, record)
            else:
                sf.Construction_Project__c.create(record)
        except Exception as e:
            return f"Error: Failed to update Salesforce - {str(e)}", "Failure", "", prediction_details

        return (
            f"Success: Milestone: {milestone}",
            "Success",
            milestone,
            prediction_details
        )
    except Exception as e:
        return f"Error: {str(e)}", "Failure", "", ""
# ---------------------------------------------------------------------------
# Gradio interface (for manual testing).
# ---------------------------------------------------------------------------
_CSS = (
    ".gradio-container {background-color: #f0f4f8; font-family: Arial;} "
    ".title {color: #2c3e50; font-size: 24px; text-align: center;}"
)

with gr.Blocks(css=_CSS) as demo:
    gr.Markdown("<h1 class='title'>Construction Milestone Detector</h1>")
    project_name = gr.Textbox(label="Project Name", placeholder="Enter project name")
    image_input = gr.Image(type="filepath", label="Upload Construction Site Photo (JPG/PNG, ≤ 20MB)")
    submit_button = gr.Button("Process Image")
    output_text = gr.Textbox(label="Result")
    upload_status = gr.Textbox(label="Upload Status")
    milestone = gr.Textbox(label="Detected Milestone")
    prediction_details = gr.Textbox(label="Top Predictions")
    progress = gr.Slider(0, 100, label="Processing Progress", interactive=False, value=0)

    # Coarse two-step progress indication wrapped around the real work.
    def update_progress():
        return 50

    def complete_progress():
        return 100

    submit_button.click(
        fn=update_progress, outputs=progress
    ).then(
        fn=process_image,
        inputs=[image_input, project_name],
        outputs=[output_text, upload_status, milestone, prediction_details],
    ).then(
        fn=complete_progress, outputs=progress
    )

# NOTE(review): despite the original "Mount FastAPI app to Gradio" comment, the
# FastAPI `app` is never mounted (e.g. via gr.mount_gradio_app) — confirm intent.
demo.launch()