lokeshloki143 commited on
Commit
17de6a3
·
verified ·
1 Parent(s): 3424e79

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -65
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import logging
3
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
@@ -7,7 +8,7 @@ from pydantic import BaseModel
7
  from simple_salesforce import Salesforce
8
  from dotenv import load_dotenv
9
 
10
- # Load environment variables from .env
11
  load_dotenv()
12
 
13
  # Configure logging
@@ -33,8 +34,8 @@ if not HF_TOKEN:
33
  logging.warning("HF_TOKEN not set. Using public models only.")
34
 
35
  # Model configuration
36
- MODEL_PATH = "facebook/bart-large" # Public model for testing
37
- # MODEL_PATH = "your_actual_username/fine_tuned_bart_construction" # Uncomment and replace after uploading your model
38
 
39
  try:
40
  logging.info(f"Attempting to load model from {MODEL_PATH}")
@@ -51,11 +52,11 @@ class ChecklistInput(BaseModel):
51
  project_id: str = "Unknown"
52
  project_name: str = "Unknown Project"
53
  milestones: str = "No milestones provided"
54
- record_id: str = None # Supervisor_AI_Coaching__c record ID
55
- supervisor_id: str = None # Supervisor_ID__c
56
- project_id_sf: str = None # Project_ID__c
57
- reflection_log: str = None # Reflection_Log__c
58
- download_link: str = None # Download_Link__c
59
 
60
  # Initialize FastAPI
61
  app = FastAPI()
@@ -63,7 +64,6 @@ app = FastAPI()
63
  @app.post("/generate")
64
  async def generate_checklist(data: ChecklistInput):
65
  try:
66
- # Generate checklist and tips
67
  inputs = f"Role: {data.role} Project: {data.project_id} ({data.project_name}) Milestones: {data.milestones}"
68
  logging.info(f"Generating checklist for inputs: {inputs[:100]}...")
69
  input_ids = tokenizer(inputs, return_tensors="pt", max_length=128, truncation=True).input_ids
@@ -72,7 +72,6 @@ async def generate_checklist(data: ChecklistInput):
72
  tips = "1. Prioritize safety checks\n2. Review milestones\n3. Log progress"
73
  kpi_flag = "delay" in data.milestones.lower() or "behind" in data.milestones.lower()
74
 
75
- # Update Salesforce if record_id is provided
76
  if data.record_id:
77
  try:
78
  sf = get_salesforce_connection()
@@ -96,14 +95,12 @@ async def generate_checklist(data: ChecklistInput):
96
  'Project_ID__c': data.project_id_sf if data.project_id_sf else existing_record.get('Project_ID__c'),
97
  'Reflection_Log__c': data.reflection_log if data.reflection_log else existing_record.get('Reflection_Log__c', ''),
98
  'Download_Link__c': data.download_link if data.download_link else existing_record.get('Download_Link__c', '')
99
- # Name is read-only, not updated
100
  }
101
  sf.Supervisor_AI_Coaching__c.update(data.record_id, update_data)
102
  logging.info(f"Updated Salesforce record {data.record_id} with fields: {update_data}")
103
  except Exception as sf_e:
104
  logging.error(f"Failed to update Salesforce: {sf_e}")
105
 
106
- logging.info("Checklist and tips generated successfully")
107
  return {
108
  "checklist": checklist,
109
  "tips": tips,
@@ -113,9 +110,49 @@ async def generate_checklist(data: ChecklistInput):
113
  logging.error(f"Error generating checklist: {e}")
114
  return {"error": str(e)}
115
 
116
- # Gradio interface function
117
- def gradio_generate_checklist(role, project_id, project_name, milestones, record_id="", supervisor_id="", project_id_sf="", reflection_log="", download_link=""):
118
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  inputs = f"Role: {role} Project: {project_id} ({project_name}) Milestones: {milestones}"
120
  input_ids = tokenizer(inputs, return_tensors="pt", max_length=128, truncation=True).input_ids
121
  outputs = model.generate(input_ids, max_length=128, num_beams=4, early_stopping=True)
@@ -123,63 +160,60 @@ def gradio_generate_checklist(role, project_id, project_name, milestones, record
123
  tips = "1. Prioritize safety checks\n2. Review milestones\n3. Log progress"
124
  kpi_flag = "delay" in milestones.lower() or "behind" in milestones.lower()
125
 
126
- status = "No Salesforce update (record_id not provided)"
127
- if record_id:
128
- sf = get_salesforce_connection()
129
- existing_record = sf.Supervisor_AI_Coaching__c.get(record_id, default={
130
- 'Name': '',
131
- 'Supervisor_ID__c': None,
132
- 'Project_ID__c': None,
133
- 'Reflection_Log__c': '',
134
- 'Download_Link__c': '',
135
- 'Engagement_Score__c': 0,
136
- 'KPI_Flag__c': False,
137
- 'Daily_Checklist__c': '',
138
- 'Suggested_Tips__c': ''
139
- })
140
- update_data = {
141
- 'Daily_Checklist__c': checklist,
142
- 'Suggested_Tips__c': tips,
143
- 'Engagement_Score__c': existing_record.get('Engagement_Score__c', 0) + 10,
144
- 'KPI_Flag__c': kpi_flag,
145
- 'Supervisor_ID__c': supervisor_id if supervisor_id else existing_record.get('Supervisor_ID__c'),
146
- 'Project_ID__c': project_id_sf if project_id_sf else existing_record.get('Project_ID__c'),
147
- 'Reflection_Log__c': reflection_log if reflection_log else existing_record.get('Reflection_Log__c', ''),
148
- 'Download_Link__c': download_link if download_link else existing_record.get('Download_Link__c', '')
149
- # Name is read-only
150
- }
151
- sf.Supervisor_AI_Coaching__c.update(record_id, update_data)
152
- status = f"Updated Salesforce record {record_id} with fields: {update_data}"
153
 
154
  return checklist, tips, kpi_flag, status
155
  except Exception as e:
156
  return f"Error: {str(e)}", "", False, ""
157
 
158
  # Define Gradio interface
159
- iface = gr.Interface(
160
- fn=gradio_generate_checklist,
161
- inputs=[
162
- gr.Textbox(label="Role", value="Supervisor"),
163
- gr.Textbox(label="Project ID", value="P001"),
164
- gr.Textbox(label="Project Name", value="Building A"),
165
- gr.Textbox(label="Milestones", value="Complete foundation by 5/15"),
166
- gr.Textbox(label="Salesforce Record ID (optional)", value=""),
167
- gr.Textbox(label="Supervisor ID (Salesforce User ID, optional)", value=""),
168
- gr.Textbox(label="Project ID (Salesforce Project__c ID, optional)", value=""),
169
- gr.Textbox(label="Reflection Log (optional)", value=""),
170
- gr.Textbox(label="Download Link (optional)", value="")
171
- ],
172
- outputs=[
173
- gr.Textbox(label="Checklist"),
174
- gr.Textbox(label="Tips"),
175
- gr.Checkbox(label="KPI Flag"),
176
- gr.Textbox(label="Salesforce Status")
177
- ],
178
- title="AI Coach for Site Supervisors",
179
- description="Generate daily checklists and tips, and update Salesforce records."
180
- )
181
-
182
- # Mount FastAPI to Gradio
 
 
 
 
 
 
 
 
 
 
 
 
183
  iface.app = app
184
 
185
  if __name__ == "__main__":
@@ -189,3 +223,4 @@ if __name__ == "__main__":
189
  except Exception as e:
190
  logging.error(f"Failed to launch Gradio interface: {e}")
191
  raise e
 
 
1
+ ```python
2
  import gradio as gr
3
  import logging
4
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
8
  from simple_salesforce import Salesforce
9
  from dotenv import load_dotenv
10
 
11
+ # Load environment variables
12
  load_dotenv()
13
 
14
  # Configure logging
 
34
  logging.warning("HF_TOKEN not set. Using public models only.")
35
 
36
  # Model configuration
37
+ MODEL_PATH = "facebook/bart-large" # Public model
38
+ # MODEL_PATH = "your_actual_username/fine_tuned_bart_construction" # Uncomment after uploading
39
 
40
  try:
41
  logging.info(f"Attempting to load model from {MODEL_PATH}")
 
52
  project_id: str = "Unknown"
53
  project_name: str = "Unknown Project"
54
  milestones: str = "No milestones provided"
55
+ record_id: str = None
56
+ supervisor_id: str = None
57
+ project_id_sf: str = None
58
+ reflection_log: str = None
59
+ download_link: str = None
60
 
61
  # Initialize FastAPI
62
  app = FastAPI()
 
64
  @app.post("/generate")
65
  async def generate_checklist(data: ChecklistInput):
66
  try:
 
67
  inputs = f"Role: {data.role} Project: {data.project_id} ({data.project_name}) Milestones: {data.milestones}"
68
  logging.info(f"Generating checklist for inputs: {inputs[:100]}...")
69
  input_ids = tokenizer(inputs, return_tensors="pt", max_length=128, truncation=True).input_ids
 
72
  tips = "1. Prioritize safety checks\n2. Review milestones\n3. Log progress"
73
  kpi_flag = "delay" in data.milestones.lower() or "behind" in data.milestones.lower()
74
 
 
75
  if data.record_id:
76
  try:
77
  sf = get_salesforce_connection()
 
95
  'Project_ID__c': data.project_id_sf if data.project_id_sf else existing_record.get('Project_ID__c'),
96
  'Reflection_Log__c': data.reflection_log if data.reflection_log else existing_record.get('Reflection_Log__c', ''),
97
  'Download_Link__c': data.download_link if data.download_link else existing_record.get('Download_Link__c', '')
 
98
  }
99
  sf.Supervisor_AI_Coaching__c.update(data.record_id, update_data)
100
  logging.info(f"Updated Salesforce record {data.record_id} with fields: {update_data}")
101
  except Exception as sf_e:
102
  logging.error(f"Failed to update Salesforce: {sf_e}")
103
 
 
104
  return {
105
  "checklist": checklist,
106
  "tips": tips,
 
110
  logging.error(f"Error generating checklist: {e}")
111
  return {"error": str(e)}
112
 
113
+ # Login and display records
114
+ def login_and_display(project_id_sf):
115
  try:
116
+ sf = get_salesforce_connection()
117
+ query = f"SELECT Id, Name, Supervisor_ID__c, Project_ID__c, Daily_Checklist__c, Suggested_Tips__c, Reflection_Log__c, Engagement_Score__c, KPI_Flag__c, Download_Link__c FROM Supervisor_AI_Coaching__c WHERE Project_ID__c = '{project_id_sf}'"
118
+ records = sf.query(query)["records"]
119
+ if not records:
120
+ return "No records found for Project ID.", "", False, ""
121
+
122
+ output = "Supervisor_AI_Coaching__c Records:\n"
123
+ for record in records:
124
+ output += (
125
+ f"Record ID: {record['Id']}\n"
126
+ f"Name: {record['Name']}\n"
127
+ f"Supervisor ID: {record['Supervisor_ID__c']}\n"
128
+ f"Project ID: {record['Project_ID__c']}\n"
129
+ f"Daily Checklist: {record['Daily_Checklist__c'] or 'N/A'}\n"
130
+ f"Suggested Tips: {record['Suggested_Tips__c'] or 'N/A'}\n"
131
+ f"Reflection Log: {record['Reflection_Log__c'] or 'N/A'}\n"
132
+ f"Engagement Score: {record['Engagement_Score__c'] or 0}%\n"
133
+ f"KPI Flag: {record['KPI_Flag__c']}\n"
134
+ f"Download Link: {record['Download_Link__c'] or 'N/A'}\n"
135
+ f"{'-'*50}\n"
136
+ )
137
+ return output, "", False, ""
138
+ except Exception as e:
139
+ return f"Error querying Salesforce: {str(e)}", "", False, ""
140
+
141
+ # Generate checklist from record
142
+ def gradio_generate_checklist(record_id, role="Supervisor", project_id="Unknown", project_name="Unknown Project", milestones="No milestones provided", supervisor_id="", project_id_sf="", reflection_log="", download_link=""):
143
+ try:
144
+ sf = get_salesforce_connection()
145
+ existing_record = sf.Supervisor_AI_Coaching__c.get(record_id, default={
146
+ 'Name': '',
147
+ 'Supervisor_ID__c': None,
148
+ 'Project_ID__c': None,
149
+ 'Reflection_Log__c': '',
150
+ 'Download_Link__c': '',
151
+ 'Engagement_Score__c': 0,
152
+ 'KPI_Flag__c': False,
153
+ 'Daily_Checklist__c': '',
154
+ 'Suggested_Tips__c': ''
155
+ })
156
  inputs = f"Role: {role} Project: {project_id} ({project_name}) Milestones: {milestones}"
157
  input_ids = tokenizer(inputs, return_tensors="pt", max_length=128, truncation=True).input_ids
158
  outputs = model.generate(input_ids, max_length=128, num_beams=4, early_stopping=True)
 
160
  tips = "1. Prioritize safety checks\n2. Review milestones\n3. Log progress"
161
  kpi_flag = "delay" in milestones.lower() or "behind" in milestones.lower()
162
 
163
+ update_data = {
164
+ 'Daily_Checklist__c': checklist,
165
+ 'Suggested_Tips__c': tips,
166
+ 'Engagement_Score__c': existing_record.get('Engagement_Score__c', 0) + 10,
167
+ 'KPI_Flag__c': kpi_flag,
168
+ 'Supervisor_ID__c': supervisor_id if supervisor_id else existing_record.get('Supervisor_ID__c'),
169
+ 'Project_ID__c': project_id_sf if project_id_sf else existing_record.get('Project_ID__c'),
170
+ 'Reflection_Log__c': reflection_log if reflection_log else existing_record.get('Reflection_Log__c', ''),
171
+ 'Download_Link__c': download_link if download_link else existing_record.get('Download_Link__c', '')
172
+ }
173
+ sf.Supervisor_AI_Coaching__c.update(record_id, update_data)
174
+ status = f"Updated Salesforce record {record_id} with fields: {update_data}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
  return checklist, tips, kpi_flag, status
177
  except Exception as e:
178
  return f"Error: {str(e)}", "", False, ""
179
 
180
  # Define Gradio interface
181
+ with gr.Blocks() as iface:
182
+ gr.Markdown("# AI Coach for Site Supervisors")
183
+ gr.Markdown("Enter a Project ID to view Supervisor_AI_Coaching__c records and generate checklists.")
184
+
185
+ with gr.Tab("Login"):
186
+ project_id_input = gr.Textbox(label="Project ID (Salesforce Project__c ID)", placeholder="Enter Project ID")
187
+ login_button = gr.Button("Submit")
188
+ records_output = gr.Textbox(label="Records", lines=10)
189
+ login_button.click(
190
+ fn=login_and_display,
191
+ inputs=project_id_input,
192
+ outputs=[records_output, gr.Textbox(visible=False), gr.Checkbox(visible=False), gr.Textbox(visible=False)]
193
+ )
194
+
195
+ with gr.Tab("Generate Checklist"):
196
+ record_id = gr.Textbox(label="Record ID", placeholder="Enter Record ID from above")
197
+ role = gr.Textbox(label="Role", value="Supervisor")
198
+ project_id = gr.Textbox(label="Project ID", value="P001")
199
+ project_name = gr.Textbox(label="Project Name", value="Building A")
200
+ milestones = gr.Textbox(label="Milestones", value="Complete foundation by 5/15")
201
+ supervisor_id = gr.Textbox(label="Supervisor ID (Salesforce User ID, optional)", value="")
202
+ project_id_sf = gr.Textbox(label="Project ID (Salesforce Project__c ID, optional)", value="")
203
+ reflection_log = gr.Textbox(label="Reflection Log (optional)", value="")
204
+ download_link = gr.Textbox(label="Download Link (optional)", value="")
205
+ generate_button = gr.Button("Generate and Update")
206
+ checklist_output = gr.Textbox(label="Checklist")
207
+ tips_output = gr.Textbox(label="Tips")
208
+ kpi_flag_output = gr.Checkbox(label="KPI Flag")
209
+ status_output = gr.Textbox(label="Salesforce Status")
210
+ generate_button.click(
211
+ fn=gradio_generate_checklist,
212
+ inputs=[record_id, role, project_id, project_name, milestones, supervisor_id, project_id_sf, reflection_log, download_link],
213
+ outputs=[checklist_output, tips_output, kpi_flag_output, status_output]
214
+ )
215
+
216
+ # Mount FastAPI
217
  iface.app = app
218
 
219
  if __name__ == "__main__":
 
223
  except Exception as e:
224
  logging.error(f"Failed to launch Gradio interface: {e}")
225
  raise e
226
+ ```