lokeshloki143 commited on
Commit
d282ec1
·
verified ·
1 Parent(s): b0cfc0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -18
app.py CHANGED
@@ -1,11 +1,11 @@
1
-
2
  import gradio as gr
3
  import logging
4
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
5
  import os
6
- from fastapi import FastAPI, Request
7
  from pydantic import BaseModel
8
- import uvicorn
 
9
 
10
  # Configure logging
11
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -13,19 +13,50 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
13
  # Load Hugging Face token from environment variable
14
  HF_TOKEN = os.getenv("HF_TOKEN")
15
  if not HF_TOKEN:
16
- logging.warning("HF_TOKEN not set. Private repositories may be inaccessible.")
17
 
18
  # Model configuration
19
- MODEL_PATH = "your_username/fine_tuned_bart_construction" # Replace with your model path
 
20
 
21
  try:
22
- # Load model and tokenizer with authentication
 
23
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH, use_auth_token=HF_TOKEN)
24
  tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_auth_token=HF_TOKEN)
25
  logging.info(f"Model and tokenizer loaded successfully from {MODEL_PATH}")
26
  except Exception as e:
27
- logging.error(f"Failed to load model or tokenizer: {e}")
28
- raise e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  # Define input model for FastAPI
31
  class ChecklistInput(BaseModel):
@@ -33,25 +64,52 @@ class ChecklistInput(BaseModel):
33
  project_id: str = "Unknown"
34
  project_name: str = "Unknown Project"
35
  milestones: str = "No milestones provided"
 
36
 
37
- # Initialize FastAPI for JSON endpoint (Gradio mounts it)
38
  app = FastAPI()
39
 
40
  @app.post("/generate")
41
  async def generate_checklist(data: ChecklistInput):
42
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  # Prepare input for the model
44
- inputs = f"Role: {data.role} Project: {data.project_id} ({data.project_name}) Milestones: {data.milestones}"
45
- logging.info(f"Generating checklist for inputs: {inputs[:100]}...") # Truncate for logging
46
 
47
  # Tokenize and generate
48
  input_ids = tokenizer(inputs, return_tensors="pt", max_length=128, truncation=True).input_ids
49
  outputs = model.generate(input_ids, max_length=128, num_beams=4, early_stopping=True)
50
  checklist = tokenizer.decode(outputs[0], skip_special_tokens=True)
51
 
52
- # Static tips (can be enhanced with model)
53
  tips = "1. Prioritize safety checks\n2. Review milestones\n3. Log progress"
54
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  logging.info("Checklist and tips generated successfully")
56
  return {
57
  "checklist": checklist,
@@ -62,13 +120,29 @@ async def generate_checklist(data: ChecklistInput):
62
  return {"error": str(e)}
63
 
64
  # Gradio interface function
65
- def gradio_generate_checklist(role, project_id, project_name, milestones):
66
  try:
 
 
 
 
 
 
 
 
67
  inputs = f"Role: {role} Project: {project_id} ({project_name}) Milestones: {milestones}"
68
  input_ids = tokenizer(inputs, return_tensors="pt", max_length=128, truncation=True).input_ids
69
  outputs = model.generate(input_ids, max_length=128, num_beams=4, early_stopping=True)
70
  checklist = tokenizer.decode(outputs[0], skip_special_tokens=True)
71
  tips = "1. Prioritize safety checks\n2. Review milestones\n3. Log progress"
 
 
 
 
 
 
 
 
72
  return checklist, tips
73
  except Exception as e:
74
  return f"Error: {str(e)}", ""
@@ -78,16 +152,17 @@ iface = gr.Interface(
78
  fn=gradio_generate_checklist,
79
  inputs=[
80
  gr.Textbox(label="Role", value="Supervisor"),
81
- gr.Textbox(label="Project ID", value="P001"),
82
  gr.Textbox(label="Project Name", value="Building A"),
83
- gr.Textbox(label="Milestones", value="Complete foundation by 5/15")
 
84
  ],
85
  outputs=[
86
  gr.Textbox(label="Checklist"),
87
  gr.Textbox(label="Tips")
88
  ],
89
  title="AI Coach for Site Supervisors",
90
- description="Generate daily checklists and tips for site supervisors."
91
  )
92
 
93
  # Mount FastAPI to Gradio
@@ -95,9 +170,8 @@ iface.app = app
95
 
96
  if __name__ == "__main__":
97
  try:
98
- # Launch Gradio with public URL for Hugging Face Spaces
99
  iface.launch(server_name="0.0.0.0", server_port=7860, share=False)
100
  logging.info("Gradio interface launched successfully")
101
  except Exception as e:
102
  logging.error(f"Failed to launch Gradio interface: {e}")
103
- raise e
 
 
1
  import gradio as gr
2
  import logging
3
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
  import os
5
+ from fastapi import FastAPI
6
  from pydantic import BaseModel
7
+ from simple_salesforce import Salesforce
8
+ import json
9
 
10
  # Configure logging
11
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
13
  # Load Hugging Face token from environment variable
14
  HF_TOKEN = os.getenv("HF_TOKEN")
15
  if not HF_TOKEN:
16
+ logging.warning("HF_TOKEN not set. Attempting to load public models only.")
17
 
18
  # Model configuration
19
+ MODEL_PATH = "your_username/fine_tuned_bart_construction" # Replace with your actual model path
20
+ FALLBACK_MODEL = "facebook/bart-large" # Fallback public model
21
 
22
  try:
23
+ # Try loading fine-tuned model
24
+ logging.info(f"Attempting to load model from {MODEL_PATH}")
25
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH, use_auth_token=HF_TOKEN)
26
  tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_auth_token=HF_TOKEN)
27
  logging.info(f"Model and tokenizer loaded successfully from {MODEL_PATH}")
28
  except Exception as e:
29
+ logging.error(f"Failed to load model from {MODEL_PATH}: {e}")
30
+ logging.info(f"Falling back to public model: {FALLBACK_MODEL}")
31
+ try:
32
+ model = AutoModelForSeq2SeqLM.from_pretrained(FALLBACK_MODEL)
33
+ tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
34
+ logging.info(f"Fallback model {FALLBACK_MODEL} loaded successfully")
35
+ except Exception as fallback_e:
36
+ logging.error(f"Failed to load fallback model {FALLBACK_MODEL}: {fallback_e}")
37
+ raise fallback_e
38
+
39
+ # Salesforce connection
40
+ def get_salesforce_connection():
41
+ try:
42
+ sf_username = os.getenv("SF_USERNAME")
43
+ sf_password = os.getenv("SF_PASSWORD")
44
+ sf_security_token = os.getenv("SF_SECURITY_TOKEN")
45
+
46
+ if not all([sf_username, sf_password, sf_security_token]):
47
+ logging.error("Salesforce credentials missing in environment variables")
48
+ raise ValueError("Salesforce credentials not configured")
49
+
50
+ sf = Salesforce(
51
+ username=sf_username,
52
+ password=sf_password,
53
+ security_token=sf_security_token
54
+ )
55
+ logging.info("Connected to Salesforce successfully")
56
+ return sf
57
+ except Exception as e:
58
+ logging.error(f"Failed to connect to Salesforce: {e}")
59
+ raise e
60
 
61
  # Define input model for FastAPI
62
  class ChecklistInput(BaseModel):
 
64
  project_id: str = "Unknown"
65
  project_name: str = "Unknown Project"
66
  milestones: str = "No milestones provided"
67
+ coaching_record_id: str = None # Optional Salesforce record ID
68
 
69
+ # Initialize FastAPI
70
  app = FastAPI()
71
 
72
  @app.post("/generate")
73
  async def generate_checklist(data: ChecklistInput):
74
  try:
75
+ # Connect to Salesforce
76
+ sf = get_salesforce_connection()
77
+
78
+ # Fetch milestones from Project__c if project_id is provided
79
+ milestones = data.milestones
80
+ if data.project_id != "Unknown":
81
+ try:
82
+ project = sf.query(f"SELECT Milestones__c FROM Project__c WHERE Id = '{data.project_id}'")
83
+ if project["totalSize"] > 0:
84
+ milestones = project["records"][0]["Milestones__c"] or milestones
85
+ logging.info(f"Fetched milestones for project {data.project_id}: {milestones}")
86
+ except Exception as e:
87
+ logging.warning(f"Failed to fetch project milestones: {e}")
88
+
89
  # Prepare input for the model
90
+ inputs = f"Role: {data.role} Project: {data.project_id} ({data.project_name}) Milestones: {milestones}"
91
+ logging.info(f"Generating checklist for inputs: {inputs[:100]}...")
92
 
93
  # Tokenize and generate
94
  input_ids = tokenizer(inputs, return_tensors="pt", max_length=128, truncation=True).input_ids
95
  outputs = model.generate(input_ids, max_length=128, num_beams=4, early_stopping=True)
96
  checklist = tokenizer.decode(outputs[0], skip_special_tokens=True)
97
 
98
+ # Static tips (replace with model-generated tips after fine-tuning)
99
  tips = "1. Prioritize safety checks\n2. Review milestones\n3. Log progress"
100
 
101
+ # Update Salesforce record if coaching_record_id is provided
102
+ if data.coaching_record_id:
103
+ try:
104
+ sf.Supervisor_AI_Coaching__c.update(data.coaching_record_id, {
105
+ "Daily_Checklist__c": checklist,
106
+ "Suggested_Tips__c": tips,
107
+ "Engagement_Score__c": 10 # Increment or set as needed
108
+ })
109
+ logging.info(f"Updated Salesforce record {data.coaching_record_id}")
110
+ except Exception as e:
111
+ logging.error(f"Failed to update Salesforce record: {e}")
112
+
113
  logging.info("Checklist and tips generated successfully")
114
  return {
115
  "checklist": checklist,
 
120
  return {"error": str(e)}
121
 
122
  # Gradio interface function
123
+ def gradio_generate_checklist(role, project_id, project_name, milestones, coaching_record_id=""):
124
  try:
125
+ sf = get_salesforce_connection()
126
+
127
+ # Fetch milestones if project_id is valid
128
+ if project_id:
129
+ project = sf.query(f"SELECT Milestones__c FROM Project__c WHERE Id = '{project_id}'")
130
+ if project["totalSize"] > 0:
131
+ milestones = project["records"][0]["Milestones__c"] or milestones
132
+
133
  inputs = f"Role: {role} Project: {project_id} ({project_name}) Milestones: {milestones}"
134
  input_ids = tokenizer(inputs, return_tensors="pt", max_length=128, truncation=True).input_ids
135
  outputs = model.generate(input_ids, max_length=128, num_beams=4, early_stopping=True)
136
  checklist = tokenizer.decode(outputs[0], skip_special_tokens=True)
137
  tips = "1. Prioritize safety checks\n2. Review milestones\n3. Log progress"
138
+
139
+ if coaching_record_id:
140
+ sf.Supervisor_AI_Coaching__c.update(coaching_record_id, {
141
+ "Daily_Checklist__c": checklist,
142
+ "Suggested_Tips__c": tips,
143
+ "Engagement_Score__c": 10
144
+ })
145
+
146
  return checklist, tips
147
  except Exception as e:
148
  return f"Error: {str(e)}", ""
 
152
  fn=gradio_generate_checklist,
153
  inputs=[
154
  gr.Textbox(label="Role", value="Supervisor"),
155
+ gr.Textbox(label="Project ID", value=""),
156
  gr.Textbox(label="Project Name", value="Building A"),
157
+ gr.Textbox(label="Milestones", value="Complete foundation by 5/15"),
158
+ gr.Textbox(label="Coaching Record ID", value="")
159
  ],
160
  outputs=[
161
  gr.Textbox(label="Checklist"),
162
  gr.Textbox(label="Tips")
163
  ],
164
  title="AI Coach for Site Supervisors",
165
+ description="Generate daily checklists and tips, with Salesforce integration."
166
  )
167
 
168
  # Mount FastAPI to Gradio
 
170
 
171
  if __name__ == "__main__":
172
  try:
 
173
  iface.launch(server_name="0.0.0.0", server_port=7860, share=False)
174
  logging.info("Gradio interface launched successfully")
175
  except Exception as e:
176
  logging.error(f"Failed to launch Gradio interface: {e}")
177
+ raise e