lokeshloki143 commited on
Commit
0d3b455
·
verified ·
1 Parent(s): d6f6675

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +234 -0
app.py CHANGED
@@ -1,6 +1,240 @@
1
  import gradio as gr
 
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  # Initialize model and tokenizer
6
  model_name = "distilgpt2"
 
1
  import gradio as gr
2
+ import torchimport gradio as gr
3
  import torch
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ from simple_salesforce import Salesforce
6
+ from dotenv import load_dotenv
7
+ import os
8
+ import json
9
+ from fastapi import FastAPI, HTTPException
10
+ import uvicorn
11
+
12
+ # Load environment variables from .env file
13
+ load_dotenv()
14
+
15
+ # Salesforce credentials
16
+ SF_USERNAME = os.getenv('SF_USERNAME')
17
+ SF_PASSWORD = os.getenv('SF_PASSWORD')
18
+ SF_SECURITY_TOKEN = os.getenv('SF_SECURITY_TOKEN')
19
+ SF_DOMAIN = os.getenv('SF_DOMAIN')
20
+ HUGGINGFACE_API_KEY = os.getenv('HUGGINGFACE_API_KEY')
21
+
22
+ # Initialize Salesforce connection
23
+ try:
24
+ sf = Salesforce(
25
+ username=SF_USERNAME,
26
+ password=SF_PASSWORD,
27
+ security_token=SF_SECURITY_TOKEN,
28
+ domain=SF_DOMAIN
29
+ )
30
+ except Exception as e:
31
+ print(f"Error connecting to Salesforce: {e}")
32
+
33
+ # Initialize model and tokenizer
34
+ model_name = "distilgpt2"
35
+ try:
36
+ tokenizer = AutoTokenizer.from_pretrained(
37
+ model_name,
38
+ cache_dir="./model_cache",
39
+ token=HUGGINGFACE_API_KEY
40
+ )
41
+ model = AutoModelForCausalLM.from_pretrained(
42
+ model_name,
43
+ cache_dir="./model_cache",
44
+ token=HUGGINGFACE_API_KEY
45
+ )
46
+ except Exception as e:
47
+ print(f"Error loading model: {e}")
48
+ tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_API_KEY)
49
+ model = AutoModelForCausalLM.from_pretrained(model_name, token=HUGGINGFACE_API_KEY)
50
+
51
+ # Set pad_token to eos_token if not already set
52
+ if tokenizer.pad_token is None:
53
+ tokenizer.pad_token = tokenizer.eos_token
54
+ model.config.pad_token_id = tokenizer.eos_token_id
55
+
56
+ # Simplified prompt template
57
+ PROMPT_TEMPLATE = """Role: {role}
58
+ Project: {project_id}
59
+ Milestones:
60
+ - {milestones_list}
61
+ Reflection: {reflection}
62
+ Generate:
63
+ Checklist:
64
+ - {milestones_list}
65
+ Suggestions:
66
+ {suggestions_list}
67
+ Quote:
68
+ {your_motivational_quote}"""
69
+
70
+ def generate_outputs(role, project_id, milestones, reflection):
71
+ # Input validation
72
+ if not all([role, project_id, milestones, reflection]):
73
+ return "Error: All fields are required.", "", ""
74
+
75
+ # Process milestones
76
+ milestones_list = "\n- ".join([m.strip() for m in milestones.split(",") if m.strip()])
77
+ if not milestones_list:
78
+ return "Error: At least one valid milestone is required.", "", ""
79
+
80
+ # Generate suggestions based on reflection
81
+ suggestions_list = []
82
+ reflection_lower = reflection.lower()
83
+ if "delays" in reflection_lower:
84
+ suggestions_list.extend(["Adjust timelines for delays.", "Communicate with stakeholders."])
85
+ if "weather" in reflection_lower:
86
+ suggestions_list.extend(["Ensure rain gear availability.", "Monitor weather updates."])
87
+ if "equipment" in reflection_lower:
88
+ suggestions_list.extend(["Inspect equipment.", "Schedule maintenance."])
89
+ suggestions_list = "\n- ".join(suggestions_list) if suggestions_list else "No specific suggestions."
90
+
91
+ # Format prompt
92
+ prompt = PROMPT_TEMPLATE.format(
93
+ role=role,
94
+ project_id=project_id,
95
+ milestones_list=milestones_list.replace("\n- ", "\n- "),
96
+ reflection=reflection,
97
+ suggestions_list=suggestions_list,
98
+ your_motivational_quote="Your motivational quote here"
99
+ )
100
+
101
+ # Tokenize input
102
+ try:
103
+ inputs = tokenizer(
104
+ prompt,
105
+ return_tensors="pt",
106
+ max_length=512,
107
+ truncation=True,
108
+ padding=True
109
+ )
110
+ except Exception as e:
111
+ return f"Error in tokenization: {e}", "", ""
112
+
113
+ # Generate output
114
+ try:
115
+ with torch.no_grad():
116
+ outputs = model.generate(
117
+ input_ids=inputs['input_ids'],
118
+ attention_mask=inputs['attention_mask'],
119
+ max_length=600,
120
+ num_return_sequences=1,
121
+ no_repeat_ngram_size=2,
122
+ top_k=50,
123
+ top_p=0.95,
124
+ temperature=0.7,
125
+ pad_token_id=tokenizer.eos_token_id,
126
+ do_sample=True
127
+ )
128
+
129
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
130
+ except Exception as e:
131
+ return f"Error in generation: {e}", "", ""
132
+
133
+ # Extract sections
134
+ checklist = suggestions = quote = "Not generated."
135
+ try:
136
+ if "Checklist:" in generated_text:
137
+ checklist_start = generated_text.find("Checklist:") + 10
138
+ suggestions_start = generated_text.find("Suggestions:", checklist_start)
139
+ if suggestions_start == -1:
140
+ suggestions_start = len(generated_text)
141
+ checklist = generated_text[checklist_start:suggestions_start].strip()
142
+
143
+ if "Suggestions:" in generated_text:
144
+ suggestions_start = generated_text.find("Suggestions:") + 12
145
+ quote_start = generated_text.find("Quote:", suggestions_start)
146
+ if quote_start == -1:
147
+ quote_start = len(generated_text)
148
+ suggestions = generated_text[suggestions_start:quote_start].strip()
149
+
150
+ if "Quote:" in generated_text:
151
+ quote_start = generated_text.find("Quote:") + 7
152
+ quote = generated_text[quote_start:].strip()
153
+ except Exception as e:
154
+ return f"Error parsing output: {e}", "", ""
155
+
156
+ # Save to Salesforce
157
+ try:
158
+ project = sf.Project__c.get(project_id)
159
+ supervisor_id = project['Supervisor__c']
160
+ coaching_data = {
161
+ 'Supervisor_ID__c': supervisor_id,
162
+ 'Project_ID__c': project_id,
163
+ 'Daily_Checklist__c': checklist,
164
+ 'Suggested_Tips__c': suggestions,
165
+ 'KPI_Flag__c': 'delay' in suggestions.lower() or 'issue' in suggestions.lower(),
166
+ 'Last_Refresh_Date__c': datetime.now().isoformat()
167
+ }
168
+ sf.Supervisor_AI_Coaching__c.create(coaching_data)
169
+ except Exception as e:
170
+ print(f"Error saving to Salesforce: {e}")
171
+
172
+ return checklist, suggestions, quote
173
+
174
+ def create_interface():
175
+ with gr.Blocks() as demo:
176
+ gr.Markdown("### Construction Supervisor AI Coach")
177
+ with gr.Row():
178
+ role = gr.Dropdown(choices=["Supervisor", "Foreman", "Project Manager"], label="Role", value="Supervisor")
179
+ project_id = gr.Textbox(label="Project ID", placeholder="e.g., PROJ-001")
180
+ milestones = gr.Textbox(
181
+ label="Milestones (comma-separated)",
182
+ placeholder="e.g., Foundation complete, Framing started, Roof installed"
183
+ )
184
+ reflection = gr.Textbox(
185
+ label="Reflection",
186
+ lines=3,
187
+ placeholder="e.g., Facing delays due to weather and equipment issues."
188
+ )
189
+ with gr.Row():
190
+ submit = gr.Button("Generate")
191
+ clear = gr.Button("Clear")
192
+ checklist_output = gr.Textbox(label="Checklist", lines=4)
193
+ suggestions_output = gr.Textbox(label="Suggestions", lines=4)
194
+ quote_output = gr.Textbox(label="Quote", lines=2)
195
+
196
+ submit.click(
197
+ fn=generate_outputs,
198
+ inputs=[role, project_id, milestones, reflection],
199
+ outputs=[checklist_output, suggestions_output, quote_output]
200
+ )
201
+ clear.click(
202
+ fn=lambda: ("Supervisor", "", "", ""),
203
+ inputs=None,
204
+ outputs=[role, project_id, milestones, reflection]
205
+ )
206
+
207
+ return demo
208
+
209
+ # FastAPI for Salesforce integration
210
+ app = FastAPI()
211
+
212
+ @app.post("/api/predict")
213
+ async def predict(data: dict):
214
+ try:
215
+ role = data.get('role')
216
+ project_id = data.get('project_id')
217
+ milestones = data.get('milestones')
218
+ reflection = data.get('reflection')
219
+ checklist, suggestions, quote = generate_outputs(role, project_id, milestones, reflection)
220
+ return {
221
+ 'checklist': checklist,
222
+ 'suggestions': suggestions,
223
+ 'quote': quote
224
+ }
225
+ except Exception as e:
226
+ raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
227
+
228
+ if __name__ == "__main__":
229
+ try:
230
+ # For local testing with Gradio
231
+ demo = create_interface()
232
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
233
+ except Exception as e:
234
+ print(f"Error launching Gradio interface: {e}")
235
+ # Run FastAPI for production
236
+ uvicorn.run(app, host="0.0.0.0", port=7860)
237
+ from transformers import AutoModelForCausalLM, AutoTokenizer
238
 
239
  # Initialize model and tokenizer
240
  model_name = "distilgpt2"