Amol Kaushik commited on
Commit
f15dcea
·
1 Parent(s): fc08c1d

Download models from Google Drive at runtime

Browse files
Files changed (2) hide show
  1. app.py +42 -0
  2. requirements.txt +1 -0
app.py CHANGED
@@ -2,15 +2,39 @@ import gradio as gr
2
  import pandas as pd
3
  import pickle
4
  import os
 
5
 
6
  # Get directory where this script is located
7
  SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
8
 
 
 
 
 
9
  # Local paths - models loaded from A3/models/ directory
10
  MODEL_PATH = os.path.join(SCRIPT_DIR, "A3/models/champion_model_final_2.pkl")
11
  CLASSIFICATION_MODEL_PATH = os.path.join(SCRIPT_DIR, "A3/models/final_champion_model_A3.pkl")
12
  DATA_PATH = os.path.join(SCRIPT_DIR, "A3/A3_Data/train_dataset.csv")
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  model = None
15
  FEATURE_NAMES = None
16
  MODEL_METRICS = None
@@ -30,6 +54,15 @@ BODY_REGION_RECOMMENDATIONS = {
30
  def load_champion_model():
31
  global model, FEATURE_NAMES, MODEL_METRICS
32
 
 
 
 
 
 
 
 
 
 
33
  if os.path.exists(MODEL_PATH):
34
  print(f"Loading champion model from {MODEL_PATH}")
35
  with open(MODEL_PATH, "rb") as f:
@@ -50,6 +83,15 @@ def load_champion_model():
50
  def load_classification_model():
51
  global classification_model, CLASSIFICATION_FEATURE_NAMES, CLASSIFICATION_CLASSES, CLASSIFICATION_METRICS
52
 
 
 
 
 
 
 
 
 
 
53
  if os.path.exists(CLASSIFICATION_MODEL_PATH):
54
  print(f"Loading classification model from {CLASSIFICATION_MODEL_PATH}")
55
  with open(CLASSIFICATION_MODEL_PATH, "rb") as f:
 
2
  import pandas as pd
3
  import pickle
4
  import os
5
+ import requests
6
 
7
  # Get directory where this script is located
8
  SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
9
 
10
+ # Google Drive file IDs for model downloads
11
+ MODEL_GDRIVE_ID = "1ORlU0OOCBkWXVO2UFAkXaKtXfkOH7w1t" # champion_model_final_2.pkl
12
+ CLASSIFICATION_MODEL_GDRIVE_ID = "1QYVd9sHZbI4Vp21bO2Zd1vTcRpcq9wJs" # final_champion_model_A3.pkl
13
+
14
  # Local paths - models loaded from A3/models/ directory
15
  MODEL_PATH = os.path.join(SCRIPT_DIR, "A3/models/champion_model_final_2.pkl")
16
  CLASSIFICATION_MODEL_PATH = os.path.join(SCRIPT_DIR, "A3/models/final_champion_model_A3.pkl")
17
  DATA_PATH = os.path.join(SCRIPT_DIR, "A3/A3_Data/train_dataset.csv")
18
 
19
+
20
+ def download_from_gdrive(file_id, destination):
21
+ """Download a file from Google Drive."""
22
+ URL = "https://drive.google.com/uc?export=download"
23
+
24
+ session = requests.Session()
25
+ response = session.get(URL, params={'id': file_id, 'confirm': 't'}, stream=True)
26
+
27
+ # Create directory if needed
28
+ os.makedirs(os.path.dirname(destination), exist_ok=True)
29
+
30
+ with open(destination, "wb") as f:
31
+ for chunk in response.iter_content(chunk_size=32768):
32
+ if chunk:
33
+ f.write(chunk)
34
+
35
+ print(f"Downloaded to {destination}")
36
+ return True
37
+
38
  model = None
39
  FEATURE_NAMES = None
40
  MODEL_METRICS = None
 
54
  def load_champion_model():
55
  global model, FEATURE_NAMES, MODEL_METRICS
56
 
57
+ # Download from Google Drive if not exists locally
58
+ if not os.path.exists(MODEL_PATH):
59
+ print(f"Model not found locally, downloading from Google Drive...")
60
+ try:
61
+ download_from_gdrive(MODEL_GDRIVE_ID, MODEL_PATH)
62
+ except Exception as e:
63
+ print(f"Failed to download model: {e}")
64
+ return False
65
+
66
  if os.path.exists(MODEL_PATH):
67
  print(f"Loading champion model from {MODEL_PATH}")
68
  with open(MODEL_PATH, "rb") as f:
 
83
  def load_classification_model():
84
  global classification_model, CLASSIFICATION_FEATURE_NAMES, CLASSIFICATION_CLASSES, CLASSIFICATION_METRICS
85
 
86
+ # Download from Google Drive if not exists locally
87
+ if not os.path.exists(CLASSIFICATION_MODEL_PATH):
88
+ print(f"Classification model not found locally, downloading from Google Drive...")
89
+ try:
90
+ download_from_gdrive(CLASSIFICATION_MODEL_GDRIVE_ID, CLASSIFICATION_MODEL_PATH)
91
+ except Exception as e:
92
+ print(f"Failed to download classification model: {e}")
93
+ return False
94
+
95
  if os.path.exists(CLASSIFICATION_MODEL_PATH):
96
  print(f"Loading classification model from {CLASSIFICATION_MODEL_PATH}")
97
  with open(CLASSIFICATION_MODEL_PATH, "rb") as f:
requirements.txt CHANGED
@@ -4,3 +4,4 @@ numpy>=1.24.0
4
  scikit-learn==1.7.2
5
  statsmodels>=0.14.0
6
  matplotlib>=3.7.0
 
 
4
  scikit-learn==1.7.2
5
  statsmodels>=0.14.0
6
  matplotlib>=3.7.0
7
+ requests>=2.28.0