AjaykumarPilla committed on
Commit
d06b40d
·
verified ·
1 Parent(s): ea55c2f

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +22 -29
model.py CHANGED
@@ -25,62 +25,56 @@ def get_weather_condition(score: int) -> str:
25
 
26
  def call_ai_model_for_insights(input_data: Dict, delay_risk: float) -> List[str]:
27
  """
28
- Use DistilBART in Hugging Face Space (CPU) to generate insights based on input data and delay risk.
29
  """
30
- model_name = "sshleifer/distilbart-cnn-6-6"
31
  max_retries = 3
32
  retry_delay = 5 # seconds
33
 
34
  for attempt in range(max_retries):
35
  try:
36
  logger.info(f"Attempt {attempt + 1}/{max_retries} - Loading model: {model_name}")
37
- # Load tokenizer and model with minimal memory usage
38
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=False, use_fast=True)
39
  model = AutoModelForSeq2SeqLM.from_pretrained(
40
  model_name,
41
- torch_dtype=torch.float32, # CPU-compatible
42
- use_safetensors=True, # Secure loading
43
  trust_remote_code=False,
44
- low_cpu_mem_usage=True # Optimize for low memory
45
  )
46
 
47
  logger.info("Model loaded successfully. Generating insights...")
48
- # Prepare prompt
49
  prompt = f"""
50
- You are an AI assistant analyzing project delay risks for a construction project.
51
- Based on the following data, provide 2-4 concise insights or mitigation strategies as a list:
52
- - Project: {input_data.get('project_name', 'Unnamed Project')}
53
- - Phase: {input_data.get('phase', '')}
54
- - Task: {input_data.get('task', '')}
55
- - Expected Duration: {input_data.get('task_expected_duration', 0)} days
56
- - Actual Duration: {input_data.get('task_actual_duration', 0)} days
57
- - Current Progress: {input_data.get('current_progress', 0)}%
58
- - Workforce Gap: {input_data.get('workforce_gap', 0)}%
59
- - Workforce Skill Level: {input_data.get('workforce_skill_level', '').lower()}
60
- - Shift Hours: {input_data.get('workforce_shift_hours', 0)} hours
61
- - Weather Impact Score: {input_data.get('weather_impact_score', 0)} (Condition: {get_weather_condition(input_data.get('weather_impact_score', 0))})
62
- - Calculated Delay Risk: {delay_risk:.1f}%
63
-
64
- Format the response as a list of strings, e.g., ["Insight 1", "Insight 2"].
65
  """
66
 
67
- # Tokenize and generate with memory-efficient settings
68
  with torch.no_grad():
69
  inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to("cpu")
70
  outputs = model.generate(
71
  **inputs,
72
- max_new_tokens=100, # Reduced for faster CPU inference
73
  num_beams=4,
74
  temperature=0.7,
75
  do_sample=True
76
  )
77
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
78
 
79
- # Parse response into a list
80
  insights = [line.strip() for line in response.split("\n") if line.strip() and line.strip() not in [prompt]]
81
  logger.info(f"Generated insights: {insights}")
82
  return insights[:4] or ["No insights generated; review input data."]
83
-
84
  except Exception as e:
85
  logger.error(f"Attempt {attempt + 1}/{max_retries} - Model inference failed: {str(e)}")
86
  if attempt < max_retries - 1:
@@ -88,7 +82,6 @@ def call_ai_model_for_insights(input_data: Dict, delay_risk: float) -> List[str]
88
  time.sleep(retry_delay)
89
  else:
90
  logger.error("Max retries reached. Using fallback insights.")
91
- # Fallback: Generate basic rule-based insights
92
  fallback_insights = []
93
  if delay_risk > 75:
94
  fallback_insights.append("High risk detected; allocate additional resources urgently.")
@@ -98,13 +91,13 @@ def call_ai_model_for_insights(input_data: Dict, delay_risk: float) -> List[str]
98
  fallback_insights.append("Significant workforce gap; recruit additional workers.")
99
  if input_data.get('weather_impact_score', 0) > 50:
100
  fallback_insights.append("Adverse weather; prioritize indoor tasks.")
101
- return fallback_insights or ["AI model unavailable; monitor progress and resource allocation."]
102
 
103
  def predict_delay(input_data: Dict) -> Dict:
104
  """
105
  Predict delay probability based on project task data.
106
  Uses task duration, progress, workforce info, and weather impact.
107
- Insights are generated by DistilBART (CPU).
108
  """
109
  logger.info("Starting delay prediction")
110
  phase = input_data.get("phase", "")
 
25
 
26
  def call_ai_model_for_insights(input_data: Dict, delay_risk: float) -> List[str]:
27
  """
28
+ Use T5-Small in Hugging Face Space (CPU) to generate insights based on input data and delay risk.
29
  """
30
+ model_name = "t5-small"
31
  max_retries = 3
32
  retry_delay = 5 # seconds
33
 
34
  for attempt in range(max_retries):
35
  try:
36
  logger.info(f"Attempt {attempt + 1}/{max_retries} - Loading model: {model_name}")
 
37
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=False, use_fast=True)
38
  model = AutoModelForSeq2SeqLM.from_pretrained(
39
  model_name,
40
+ torch_dtype=torch.float32,
41
+ use_safetensors=True,
42
  trust_remote_code=False,
43
+ low_cpu_mem_usage=True
44
  )
45
 
46
  logger.info("Model loaded successfully. Generating insights...")
 
47
  prompt = f"""
48
+ Summarize the following project delay risk data into 2-4 concise insights or mitigation strategies as a list:
49
+ Project: {input_data.get('project_name', 'Unnamed Project')}
50
+ Phase: {input_data.get('phase', '')}
51
+ Task: {input_data.get('task', '')}
52
+ Expected Duration: {input_data.get('task_expected_duration', 0)} days
53
+ Actual Duration: {input_data.get('task_actual_duration', 0)} days
54
+ Current Progress: {input_data.get('current_progress', 0)}%
55
+ Workforce Gap: {input_data.get('workforce_gap', 0)}%
56
+ Workforce Skill Level: {input_data.get('workforce_skill_level', '').lower()}
57
+ Shift Hours: {input_data.get('workforce_shift_hours', 0)} hours
58
+ Weather Impact Score: {input_data.get('weather_impact_score', 0)} (Condition: {get_weather_condition(input_data.get('weather_impact_score', 0))})
59
+ Calculated Delay Risk: {delay_risk:.1f}%
60
+
61
+ Format the response as a list, e.g., ["Insight 1", "Insight 2"].
 
62
  """
63
 
 
64
  with torch.no_grad():
65
  inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to("cpu")
66
  outputs = model.generate(
67
  **inputs,
68
+ max_new_tokens=100,
69
  num_beams=4,
70
  temperature=0.7,
71
  do_sample=True
72
  )
73
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
74
 
 
75
  insights = [line.strip() for line in response.split("\n") if line.strip() and line.strip() not in [prompt]]
76
  logger.info(f"Generated insights: {insights}")
77
  return insights[:4] or ["No insights generated; review input data."]
 
78
  except Exception as e:
79
  logger.error(f"Attempt {attempt + 1}/{max_retries} - Model inference failed: {str(e)}")
80
  if attempt < max_retries - 1:
 
82
  time.sleep(retry_delay)
83
  else:
84
  logger.error("Max retries reached. Using fallback insights.")
 
85
  fallback_insights = []
86
  if delay_risk > 75:
87
  fallback_insights.append("High risk detected; allocate additional resources urgently.")
 
91
  fallback_insights.append("Significant workforce gap; recruit additional workers.")
92
  if input_data.get('weather_impact_score', 0) > 50:
93
  fallback_insights.append("Adverse weather; prioritize indoor tasks.")
94
+ return fallback_insights or ["AI model failed to generate insights; check system resources."]
95
 
96
  def predict_delay(input_data: Dict) -> Dict:
97
  """
98
  Predict delay probability based on project task data.
99
  Uses task duration, progress, workforce info, and weather impact.
100
+ Insights are generated by T5-Small (CPU).
101
  """
102
  logger.info("Starting delay prediction")
103
  phase = input_data.get("phase", "")