nasreshsuguru's picture
Update app.py
a5fe846 verified
import base64
import hmac
import io
import os
import re
from datetime import datetime, timezone

import gradio as gr
import torch
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Security, Depends
from fastapi.security import APIKeyHeader
from PIL import Image
from simple_salesforce import Salesforce
from transformers import ViTForImageClassification, ViTImageProcessor
# Load environment variables (from a .env file, if one is present, plus the process environment)
load_dotenv()
HF_API_KEY = os.getenv("HF_API_KEY")  # NOTE(review): not referenced anywhere else in this file
SF_USERNAME = os.getenv("SF_USERNAME")
SF_PASSWORD = os.getenv("SF_PASSWORD")
SF_SECURITY_TOKEN = os.getenv("SF_SECURITY_TOKEN")
SF_CONSUMER_KEY = os.getenv("SF_CONSUMER_KEY")
SF_CONSUMER_SECRET = os.getenv("SF_CONSUMER_SECRET")
# Shared secret clients must send in the X-API-Key header (set in Space Secrets).
# There is deliberately no fallback value: a placeholder default published in this
# source would let anyone who reads it call the API on an unconfigured deployment.
API_KEY = os.getenv("API_KEY")
if not API_KEY:
    print("Warning: API_KEY is not set; requests to /predict-milestone will be rejected.")
# Validate Salesforce credentials (for Gradio UI); report exactly which ones are missing.
_missing_sf_settings = [
    setting_name
    for setting_name, setting_value in (
        ("SF_USERNAME", SF_USERNAME),
        ("SF_PASSWORD", SF_PASSWORD),
        ("SF_SECURITY_TOKEN", SF_SECURITY_TOKEN),
        ("SF_CONSUMER_KEY", SF_CONSUMER_KEY),
        ("SF_CONSUMER_SECRET", SF_CONSUMER_SECRET),
    )
    if not setting_value
]
if _missing_sf_settings:
    raise ValueError(
        "Missing Salesforce credentials. Set SF_USERNAME, SF_PASSWORD, SF_SECURITY_TOKEN, "
        "SF_CONSUMER_KEY, and SF_CONSUMER_SECRET in environment variables. "
        f"Currently missing: {', '.join(_missing_sf_settings)}"
    )
# Initialize Salesforce connection.
# SF_DOMAIN picks the login host passed to simple_salesforce: "login" (the previous
# hard-coded value, production orgs) by default, or e.g. "test" for a sandbox.
SF_DOMAIN = os.getenv("SF_DOMAIN", "login")
try:
    sf = Salesforce(
        username=SF_USERNAME,
        password=SF_PASSWORD,
        security_token=SF_SECURITY_TOKEN,
        consumer_key=SF_CONSUMER_KEY,
        consumer_secret=SF_CONSUMER_SECRET,
        domain=SF_DOMAIN,
    )
except Exception as e:
    # Log the reason, then abort startup by re-raising the original exception.
    print(f"Salesforce connection failed: {str(e)}")
    raise
# Load the pre-trained ViT classifier and its matching image preprocessor.
# nn.Module.eval() returns the module itself, so evaluation mode is set in the
# same expression that loads the weights.
model_name = "google/vit-base-patch16-224"
model = ViTForImageClassification.from_pretrained(model_name).eval()
processor = ViTImageProcessor.from_pretrained(model_name)
# FastAPI app for API endpoint
app = FastAPI()
# API Key authentication: clients send the shared secret in the X-API-Key header.
api_key_header = APIKeyHeader(name="X-API-Key")


async def verify_api_key(api_key: str = Security(api_key_header)) -> str:
    """FastAPI dependency that authenticates a request by its X-API-Key header.

    Args:
        api_key: Value of the X-API-Key header, supplied by ``api_key_header``.

    Returns:
        The presented key, once it has been verified against ``API_KEY``.

    Raises:
        HTTPException: 503 when the server has no ``API_KEY`` configured (no key
            can be valid), 401 when the presented key does not match.
    """
    if not API_KEY:
        # Fail closed rather than comparing against an empty/None secret.
        raise HTTPException(status_code=503, detail="API key authentication is not configured")
    # Constant-time comparison so response timing does not reveal how much of the key
    # matched. Compare UTF-8 bytes: hmac.compare_digest rejects non-ASCII str arguments.
    if not hmac.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid API Key")
    return api_key
@app.post("/predict-milestone")
async def predict_milestone(payload: dict, api_key: str = Depends(verify_api_key)):
    """Classify a base64-encoded construction-site photo.

    Args:
        payload: JSON body; its ``image`` field holds base64 image bytes, optionally
            prefixed with a data-URI header such as ``data:image/png;base64,``.
        api_key: Injected by ``verify_api_key``; not used in the body.

    Returns:
        ``{"milestone": <top label from the model's id2label>,
        "percent_complete": 75}`` (the percentage is a fixed placeholder).

    Raises:
        HTTPException: 400 for a missing/non-string ``image`` field, invalid base64,
            data over 20MB, undecodable image data, or a format other than JPEG/PNG;
            500 for any other failure during preprocessing or inference.
    """
    # NOTE(review): inference runs synchronously inside an async handler, which blocks
    # the event loop for its duration; consider offloading it if concurrency matters.
    try:
        # Validate payload
        if "image" not in payload:
            raise HTTPException(status_code=400, detail="Image field is required")
        image_data = payload["image"]
        if not isinstance(image_data, str):
            raise HTTPException(status_code=400, detail="Image field must be a base64-encoded string")
        # Remove a data-URI prefix; partition() yields "" instead of raising when no comma exists.
        if image_data.startswith("data:image"):
            image_data = image_data.partition(",")[2]
        try:
            img_bytes = base64.b64decode(image_data)
        except ValueError as e:  # binascii.Error (bad padding/characters) subclasses ValueError
            raise HTTPException(status_code=400, detail="Image field is not valid base64") from e
        # Validate image size (max 20MB per TSD) before handing the bytes to the decoder.
        if len(img_bytes) > 20 * 1024 * 1024:
            raise HTTPException(status_code=400, detail="Image size exceeds 20MB")
        try:
            img = Image.open(io.BytesIO(img_bytes))
        except OSError as e:  # includes PIL.UnidentifiedImageError
            raise HTTPException(status_code=400, detail="Could not decode image data") from e
        # Validate image type. This must happen before convert(): the converted copy is a
        # library-created image whose .format is None.
        if (img.format or "").lower() not in ("jpeg", "png"):
            raise HTTPException(status_code=400, detail="Only JPG/PNG images are supported")
        # Preprocess image (downscale large photos before the processor resizes them)
        img = img.convert("RGB")
        img.thumbnail((1024, 1024), Image.Resampling.LANCZOS)
        inputs = processor(images=img, return_tensors="pt")
        # Run inference (within 5 seconds per TSD)
        with torch.no_grad():
            logits = model(**inputs).logits
        # softmax is monotonic, so the highest logit is also the highest probability.
        predicted_idx = int(torch.argmax(logits, dim=1).item())
        milestone = model.config.id2label.get(predicted_idx, "Unknown Milestone")
        # Placeholder: the classifier does not estimate progress; adjust per milestone
        # or train a model to predict it.
        percent_complete = 75
        return {
            "milestone": milestone,
            "percent_complete": percent_complete,
        }
    except HTTPException:
        # Client errors raised above must reach the caller unchanged, not become 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e
# Upload limit for the Gradio UI (20MB, matching the label on the upload widget).
_MAX_IMAGE_BYTES = 20 * 1024 * 1024
# Allowed project names: letters, digits, whitespace, and hyphens.
_PROJECT_NAME_RE = re.compile(r"[a-zA-Z0-9\s-]+")


def process_image(image, project_name):
    """Classify an uploaded photo and record the result on a Salesforce project.

    Backs the Gradio UI. Validates the inputs, runs the ViT classifier, then updates
    the ``Construction_Project__c`` record whose Name equals ``project_name``
    (incrementing ``Version__c``) or creates a new record if none matches.

    Args:
        image: Filesystem path of the uploaded image, or None when nothing was uploaded.
        project_name: Project name entered by the user.

    Returns:
        A 4-tuple ``(message, upload_status, milestone, prediction_details)`` for the
        Result / Upload Status / Detected Milestone / Top Predictions text boxes.
        On errors ``milestone`` is ""; ``prediction_details`` is filled in only if
        classification succeeded before the Salesforce update failed. Never raises.
    """
    try:
        # Validate inputs
        if image is None:
            return "Error: Please upload an image to proceed.", "Pending", "", ""
        # A whitespace-only name would otherwise satisfy the pattern check below.
        if not project_name or not project_name.strip():
            return "Error: Please enter a project name to proceed.", "Pending", "", ""
        if not _PROJECT_NAME_RE.fullmatch(project_name):
            return "Error: Project name must be alphanumeric (letters, numbers, spaces, or hyphens).", "Pending", "", ""
        # Validate image size and type
        if os.path.getsize(image) > _MAX_IMAGE_BYTES:
            return "Error: Image size exceeds 20MB.", "Failure", "", ""
        if not image.lower().endswith((".jpg", ".jpeg", ".png")):
            return "Error: Only JPG/PNG images are supported.", "Failure", "", ""

        top_predictions = _classify_top_k(image, k=3)
        milestone = top_predictions[0][0]
        prediction_details = "\n".join(f"{label}: {prob:.2f}" for label, prob in top_predictions)

        try:
            _upsert_project(project_name, milestone)
        except Exception as e:
            return f"Error: Failed to update Salesforce - {str(e)}", "Failure", "", prediction_details
        return (
            f"Success: Milestone: {milestone}",
            "Success",
            milestone,
            prediction_details,
        )
    except Exception as e:
        # UI boundary: report any unexpected failure in the Result box instead of crashing.
        return f"Error: {str(e)}", "Failure", "", ""


def _classify_top_k(image_path, k=3):
    """Return the k most likely ``(label, probability)`` pairs for the image, best first."""
    # Context manager releases the file handle even if decoding fails.
    with Image.open(image_path) as source:
        img = source.convert("RGB")
    img.thumbnail((1024, 1024), Image.Resampling.LANCZOS)
    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        probabilities = torch.softmax(model(**inputs).logits, dim=1)
    top_probs, top_indices = torch.topk(probabilities, k, dim=1)
    labels = [
        model.config.id2label.get(idx, "Unknown Milestone") for idx in top_indices[0].tolist()
    ]
    return list(zip(labels, top_probs[0].tolist()))


def _upsert_project(project_name, milestone):
    """Update the Construction_Project__c named ``project_name``, or create it."""
    # format_soql quotes and backslash-escapes the value for a SOQL string literal
    # (SOQL does not use SQL-style '' doubling).
    from simple_salesforce import format_soql

    record = {
        "Name": project_name,
        "Current_Milestone__c": milestone,
        "Last_Updated_On__c": datetime.now(timezone.utc).isoformat(),
        "Upload_Status__c": "Success",
        "Comments__c": f"AI Prediction: {milestone}",
        "Version__c": 1,
    }
    query = format_soql(
        "SELECT Id, Version__c FROM Construction_Project__c WHERE Name = {}", project_name
    )
    result = sf.query(query)
    if result["totalSize"] > 0:
        existing = result["records"][0]
        # The key is present but None when the field is empty, so .get(..., 0) is not enough.
        record["Version__c"] = (existing.get("Version__c") or 0) + 1
        sf.Construction_Project__c.update(existing["Id"], record)
    else:
        sf.Construction_Project__c.create(record)
# Gradio interface (for manual testing)
with gr.Blocks(css=".gradio-container {background-color: #f0f4f8; font-family: Arial;} .title {color: #2c3e50; font-size: 24px; text-align: center;}") as demo:
    gr.Markdown("<h1 class='title'>Construction Milestone Detector</h1>")
    project_name = gr.Textbox(label="Project Name", placeholder="Enter project name")
    image_input = gr.Image(type="filepath", label="Upload Construction Site Photo (JPG/PNG, ≤ 20MB)")
    submit_button = gr.Button("Process Image")
    output_text = gr.Textbox(label="Result")
    upload_status = gr.Textbox(label="Upload Status")
    milestone = gr.Textbox(label="Detected Milestone")
    prediction_details = gr.Textbox(label="Top Predictions")
    progress = gr.Slider(0, 100, label="Processing Progress", interactive=False, value=0)

    def update_progress():
        """Move the progress slider to 50 before processing starts."""
        return 50

    def complete_progress():
        """Move the progress slider to 100 after processing returns."""
        return 100

    # Chain: show 50% -> run classification + Salesforce update -> show 100%.
    submit_button.click(
        fn=update_progress,
        outputs=progress,
    ).then(
        fn=process_image,
        inputs=[image_input, project_name],
        outputs=[output_text, upload_status, milestone, prediction_details],
    ).then(
        fn=complete_progress,
        outputs=progress,
    )

if __name__ == "__main__":
    # `python app.py`: serve the Gradio UI on Gradio's own server, as before.
    demo.launch()
else:
    # Imported by an ASGI server (e.g. `uvicorn app:app`): serve one app exposing both the
    # /predict-milestone API and the Gradio UI at "/". Previously `app` was never served.
    # Routes registered on `app` earlier take precedence over the "/" mount.
    app = gr.mount_gradio_app(app, demo, path="/")