AjaykumarPilla committed on
Commit a283630 · verified · 1 Parent(s): 64f83b4

Update model.py

Files changed (1):
1. model.py (+32, -11)
model.py CHANGED
@@ -1,7 +1,12 @@
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 import torch
+import logging
 from typing import Dict, List
 
+# Configure logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
 def get_weather_condition(score: int) -> str:
     """Map weather impact score (0-100) to descriptive weather condition."""
     if score <= 10:
@@ -23,15 +28,18 @@ def call_ai_model_for_insights(input_data: Dict, delay_risk: float) -> List[str]
     """
     model_name = "sshleifer/distilbart-cnn-6-6"
     try:
-        # Load tokenizer and model for CPU
-        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=False)
+        logger.info(f"Loading model: {model_name}")
+        # Load tokenizer and model with minimal memory usage
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=False, use_fast=True)
         model = AutoModelForSeq2SeqLM.from_pretrained(
             model_name,
-            torch_dtype=torch.float32,  # Use float32 for CPU
-            use_safetensors=True,  # Ensure safe loading
-            trust_remote_code=False
+            torch_dtype=torch.float32,  # CPU-compatible
+            use_safetensors=True,  # Secure loading
+            trust_remote_code=False,
+            low_cpu_mem_usage=True  # Optimize for low memory
         )
 
+        logger.info("Model loaded successfully. Generating insights...")
         # Prepare prompt
         prompt = f"""
         You are an AI assistant analyzing project delay risks for a construction project.
@@ -51,13 +59,13 @@ def call_ai_model_for_insights(input_data: Dict, delay_risk: float) -> List[str]
         Format the response as a list of strings, e.g., ["Insight 1", "Insight 2"].
         """
 
-        # Tokenize and generate with no_grad for memory efficiency
+        # Tokenize and generate with memory-efficient settings
         with torch.no_grad():
             inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to("cpu")
             outputs = model.generate(
                 **inputs,
-                max_new_tokens=150,  # Smaller output for CPU efficiency
-                num_beams=4,  # Beam search for better quality
+                max_new_tokens=100,  # Reduced for faster CPU inference
+                num_beams=4,
                 temperature=0.7,
                 do_sample=True
             )
@@ -65,11 +73,22 @@ def call_ai_model_for_insights(input_data: Dict, delay_risk: float) -> List[str]
 
         # Parse response into a list
        insights = [line.strip() for line in response.split("\n") if line.strip() and line.strip() not in [prompt]]
-        return insights[:4]  # Limit to 2-4 insights
+        logger.info(f"Generated insights: {insights}")
+        return insights[:4] or ["No insights generated; review input data."]
 
     except Exception as e:
-        print(f"Error with model inference: {e}")
-        return ["AI model unavailable; monitor progress and resource allocation."]
+        logger.error(f"Model inference failed: {str(e)}")
+        # Fallback: Generate basic rule-based insights
+        fallback_insights = []
+        if delay_risk > 75:
+            fallback_insights.append("High risk detected; allocate additional resources urgently.")
+        elif delay_risk > 50:
+            fallback_insights.append("Moderate risk; consider extending shift hours or hiring staff.")
+        if input_data.get('workforce_gap', 0) > 20:
+            fallback_insights.append("Significant workforce gap; recruit additional workers.")
+        if input_data.get('weather_impact_score', 0) > 50:
+            fallback_insights.append("Adverse weather; prioritize indoor tasks.")
+        return fallback_insights or ["AI model unavailable; monitor progress and resource allocation."]
 
 def predict_delay(input_data: Dict) -> Dict:
     """
@@ -77,6 +96,7 @@ def predict_delay(input_data: Dict) -> Dict:
     Uses task duration, progress, workforce info, and weather impact.
     Insights are generated by DistilBART (CPU).
     """
+    logger.info("Starting delay prediction")
     phase = input_data.get("phase", "")
     task = input_data.get("task", "")
     expected_duration = input_data.get("task_expected_duration", 0)
@@ -153,6 +173,7 @@ def predict_delay(input_data: Dict) -> Dict:
     # Generate AI-driven insights
     insights = call_ai_model_for_insights(input_data, delay_risk)
 
+    logger.info(f"Prediction completed: Delay risk = {delay_risk:.1f}%")
     return {
         "project": input_data.get("project_name", "Unnamed Project"),
         "phase": phase,
 