nasreshsuguru commited on
Commit
a5fe846
·
verified ·
1 Parent(s): ac7a8d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -11
app.py CHANGED
@@ -7,6 +7,10 @@ import re
7
  from dotenv import load_dotenv
8
  from simple_salesforce import Salesforce
9
  from datetime import datetime
 
 
 
 
10
 
11
  # Load environment variables
12
  load_dotenv()
@@ -16,10 +20,11 @@ SF_PASSWORD = os.getenv("SF_PASSWORD")
16
  SF_SECURITY_TOKEN = os.getenv("SF_SECURITY_TOKEN")
17
  SF_CONSUMER_KEY = os.getenv("SF_CONSUMER_KEY")
18
  SF_CONSUMER_SECRET = os.getenv("SF_CONSUMER_SECRET")
 
19
 
20
- # Validate Salesforce credentials
21
  if not all([SF_USERNAME, SF_PASSWORD, SF_SECURITY_TOKEN, SF_CONSUMER_KEY, SF_CONSUMER_SECRET]):
22
- raise ValueError("Missing Salesforce credentials. Please set SF_USERNAME, SF_PASSWORD, SF_SECURITY_TOKEN, SF_CONSUMER_KEY, and SF_CONSUMER_SECRET in environment variables.")
23
 
24
  # Initialize Salesforce connection
25
  try:
@@ -36,15 +41,71 @@ except Exception as e:
36
  raise
37
 
38
  # Load pre-trained model and processor
39
- model_name = "google/vit-base-patch16-224" # Placeholder; replace with "nasreshsuguru/construction-milestone-detector" if trained
40
  processor = ViTImageProcessor.from_pretrained(model_name)
41
  model = ViTForImageClassification.from_pretrained(model_name)
42
  model.eval()
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  def process_image(image, project_name):
45
- """
46
- Process uploaded image, predict construction milestone, and update Salesforce.
47
- """
48
  try:
49
  # Validate inputs
50
  if image is None:
@@ -63,7 +124,7 @@ def process_image(image, project_name):
63
 
64
  # Preprocess image
65
  img = Image.open(image).convert("RGB")
66
- max_size = 1024 # Optimize for performance
67
  img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
68
  inputs = processor(images=img, return_tensors="pt")
69
 
@@ -96,7 +157,7 @@ def process_image(image, project_name):
96
  }
97
 
98
  try:
99
- project_name = project_name.replace("'", "''") # Basic escaping
100
  query = f"SELECT Id, Version__c FROM Construction_Project__c WHERE Name = '{project_name}'"
101
  result = sf.query(query)
102
  if result["totalSize"] > 0:
@@ -119,7 +180,7 @@ def process_image(image, project_name):
119
  except Exception as e:
120
  return f"Error: {str(e)}", "Failure", "", ""
121
 
122
- # Gradio interface
123
  with gr.Blocks(css=".gradio-container {background-color: #f0f4f8; font-family: Arial;} .title {color: #2c3e50; font-size: 24px; text-align: center;}") as demo:
124
  gr.Markdown("<h1 class='title'>Construction Milestone Detector</h1>")
125
  project_name = gr.Textbox(label="Project Name", placeholder="Enter project name")
@@ -149,5 +210,5 @@ with gr.Blocks(css=".gradio-container {background-color: #f0f4f8; font-family: A
149
  outputs=progress
150
  )
151
 
152
- if __name__ == "__main__":
153
- demo.launch() # Fixed: Properly indented under the if statement
 
7
  from dotenv import load_dotenv
8
  from simple_salesforce import Salesforce
9
  from datetime import datetime
10
+ from fastapi import FastAPI, HTTPException, Security, Depends
11
+ from fastapi.security import APIKeyHeader
12
+ import base64
13
+ import io
14
 
15
  # Load environment variables
16
  load_dotenv()
 
20
  SF_SECURITY_TOKEN = os.getenv("SF_SECURITY_TOKEN")
21
  SF_CONSUMER_KEY = os.getenv("SF_CONSUMER_KEY")
22
  SF_CONSUMER_SECRET = os.getenv("SF_CONSUMER_SECRET")
23
+ API_KEY = os.getenv("API_KEY", "your-api-key-here") # Set in Space Secrets
24
 
25
+ # Validate Salesforce credentials (for Gradio UI)
26
  if not all([SF_USERNAME, SF_PASSWORD, SF_SECURITY_TOKEN, SF_CONSUMER_KEY, SF_CONSUMER_SECRET]):
27
+ raise ValueError("Missing Salesforce credentials. Set SF_USERNAME, SF_PASSWORD, SF_SECURITY_TOKEN, SF_CONSUMER_KEY, and SF_CONSUMER_SECRET in environment variables.")
28
 
29
  # Initialize Salesforce connection
30
  try:
 
41
  raise
42
 
43
  # Load pre-trained model and processor
44
+ model_name = "google/vit-base-patch16-224"
45
  processor = ViTImageProcessor.from_pretrained(model_name)
46
  model = ViTForImageClassification.from_pretrained(model_name)
47
  model.eval()
48
 
49
+ # FastAPI app for API endpoint
50
+ app = FastAPI()
51
+
52
+ # API Key authentication
53
+ api_key_header = APIKeyHeader(name="X-API-Key")
54
+ async def verify_api_key(api_key: str = Security(api_key_header)):
55
+ if api_key != API_KEY:
56
+ raise HTTPException(status_code=401, detail="Invalid API Key")
57
+ return api_key
58
+
59
+ @app.post("/predict-milestone")
60
+ async def predict_milestone(payload: dict, api_key: str = Depends(verify_api_key)):
61
+ try:
62
+ # Validate payload
63
+ if "image" not in payload:
64
+ raise HTTPException(status_code=400, detail="Image field is required")
65
+
66
+ # Decode base64 image
67
+ image_data = payload["image"]
68
+ if image_data.startswith("data:image"):
69
+ image_data = image_data.split(",")[1] # Remove data URI prefix
70
+ img_bytes = base64.b64decode(image_data)
71
+ img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
72
+
73
+ # Validate image size (max 20MB per TSD)
74
+ img_bytes_size = len(img_bytes) / (1024 * 1024)
75
+ if img_bytes_size > 20:
76
+ raise HTTPException(status_code=400, detail="Image size exceeds 20MB")
77
+
78
+ # Validate image type
79
+ if not img.format.lower() in ["jpeg", "png"]:
80
+ raise HTTPException(status_code=400, detail="Only JPG/PNG images are supported")
81
+
82
+ # Preprocess image
83
+ max_size = 1024 # Optimize for performance
84
+ img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
85
+ inputs = processor(images=img, return_tensors="pt")
86
+
87
+ # Run inference (within 5 seconds per TSD)
88
+ with torch.no_grad():
89
+ outputs = model(**inputs)
90
+ logits = outputs.logits
91
+ probabilities = torch.softmax(logits, dim=1)
92
+
93
+ # Get top prediction
94
+ predicted_idx = torch.argmax(probabilities, dim=1).item()
95
+ milestone = model.config.id2label.get(predicted_idx, "Unknown Milestone")
96
+
97
+ # Mock percent_complete (since model doesn't provide it)
98
+ percent_complete = 75 # Adjust based on milestone or train a model to predict this
99
+
100
+ return {
101
+ "milestone": milestone,
102
+ "percent_complete": percent_complete
103
+ }
104
+
105
+ except Exception as e:
106
+ raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
107
+
108
  def process_image(image, project_name):
 
 
 
109
  try:
110
  # Validate inputs
111
  if image is None:
 
124
 
125
  # Preprocess image
126
  img = Image.open(image).convert("RGB")
127
+ max_size = 1024
128
  img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
129
  inputs = processor(images=img, return_tensors="pt")
130
 
 
157
  }
158
 
159
  try:
160
+ project_name = project_name.replace("'", "''")
161
  query = f"SELECT Id, Version__c FROM Construction_Project__c WHERE Name = '{project_name}'"
162
  result = sf.query(query)
163
  if result["totalSize"] > 0:
 
180
  except Exception as e:
181
  return f"Error: {str(e)}", "Failure", "", ""
182
 
183
+ # Gradio interface (for manual testing)
184
  with gr.Blocks(css=".gradio-container {background-color: #f0f4f8; font-family: Arial;} .title {color: #2c3e50; font-size: 24px; text-align: center;}") as demo:
185
  gr.Markdown("<h1 class='title'>Construction Milestone Detector</h1>")
186
  project_name = gr.Textbox(label="Project Name", placeholder="Enter project name")
 
210
  outputs=progress
211
  )
212
 
213
+ # Mount FastAPI app to Gradio
214
+ demo.launch()