b2u committed on
Commit
cc4cd30
·
1 Parent(s): a29d97f

removing config.json and moving the settings to the app files

Browse files
Files changed (5) hide show
  1. .dockerignore +0 -1
  2. _wsgi.py +0 -17
  3. config.json +0 -20
  4. docker-compose.yml +17 -14
  5. model.py +13 -19
.dockerignore CHANGED
@@ -7,7 +7,6 @@
7
  !Dockerfile
8
  !docker-compose.yml
9
  !*.sh
10
- !config.json
11
 
12
  # Include any other necessary files
13
  !model/**
 
7
  !Dockerfile
8
  !docker-compose.yml
9
  !*.sh
 
10
 
11
  # Include any other necessary files
12
  !model/**
_wsgi.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- import json
3
  import logging
4
  import logging.config
5
  from pathlib import Path
@@ -35,25 +34,9 @@ logging.config.dictConfig({
35
 
36
  logger = logging.getLogger(__name__)
37
 
38
- def get_config():
39
- """Load configuration from config.json"""
40
- config_path = os.path.join(os.path.dirname(__file__), 'config.json')
41
- if not os.path.exists(config_path):
42
- logger.warning(f"Config file not found at {config_path}, using default settings")
43
- return {}
44
-
45
- try:
46
- with open(config_path) as f:
47
- config = json.load(f)
48
- return config
49
- except Exception as e:
50
- logger.error(f"Error loading config: {str(e)}")
51
- return {}
52
-
53
  # Initialize the app at module level for Gunicorn
54
  app = init_app(
55
  model_class=T5Model,
56
- config=get_config(),
57
  basic_auth_user=os.environ.get('BASIC_AUTH_USER'),
58
  basic_auth_pass=os.environ.get('BASIC_AUTH_PASS')
59
  )
 
1
  import os
 
2
  import logging
3
  import logging.config
4
  from pathlib import Path
 
34
 
35
  logger = logging.getLogger(__name__)
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  # Initialize the app at module level for Gunicorn
38
  app = init_app(
39
  model_class=T5Model,
 
40
  basic_auth_user=os.environ.get('BASIC_AUTH_USER'),
41
  basic_auth_pass=os.environ.get('BASIC_AUTH_PASS')
42
  )
config.json DELETED
@@ -1,20 +0,0 @@
1
- {
2
- "model": {
3
- "name": "google/flan-t5-base",
4
- "max_length": 512,
5
- "generation_max_length": 128,
6
- "num_return_sequences": 1
7
- },
8
- "lora": {
9
- "r": 8,
10
- "alpha": 32,
11
- "dropout": 0.1,
12
- "target_modules": ["q", "v"]
13
- },
14
- "training": {
15
- "learning_rate": 1e-4,
16
- "batch_size": 1,
17
- "max_steps": 100,
18
- "save_steps": 50
19
- }
20
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
docker-compose.yml CHANGED
@@ -9,26 +9,29 @@ services:
9
  args:
10
  TEST_ENV: ${TEST_ENV}
11
  environment:
12
- # specify these parameters if you want to use basic auth for the model server
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  - BASIC_AUTH_USER=
14
  - BASIC_AUTH_PASS=
15
- # set the log level for the model server
16
  - LOG_LEVEL=DEBUG
17
- # any other parameters that you want to pass to the model server
18
- - ANY=PARAMETER
19
- # specify the number of workers and threads for the model server
20
  - WORKERS=1
21
  - THREADS=8
22
- # specify the model directory (likely you don't need to change this)
23
  - MODEL_DIR=/data/models
24
-
25
- # Specify the Label Studio URL and API key to access
26
- # uploaded, local storage and cloud storage files.
27
- # Do not use 'localhost' as it does not work within Docker containers.
28
- # Use prefix 'http://' or 'https://' for the URL always.
29
- # Determine the actual IP using 'ifconfig' (Linux/Mac) or 'ipconfig' (Windows).
30
- - LABEL_STUDIO_URL=
31
- - LABEL_STUDIO_API_KEY=
32
  ports:
33
  - "9090:9090"
34
  volumes:
 
9
  args:
10
  TEST_ENV: ${TEST_ENV}
11
  environment:
12
+ # Model settings
13
+ - MODEL_NAME=google/flan-t5-base
14
+ - MAX_LENGTH=512
15
+ - GENERATION_MAX_LENGTH=128
16
+ - NUM_RETURN_SEQUENCES=1
17
+ # LoRA settings
18
+ - LORA_R=8
19
+ - LORA_ALPHA=32
20
+ - LORA_DROPOUT=0.1
21
+ - LORA_TARGET_MODULES=q,v
22
+ # Training settings
23
+ - LEARNING_RATE=1e-4
24
+ - BATCH_SIZE=1
25
+ - MAX_STEPS=100
26
+ - SAVE_STEPS=50
27
+ # Other settings
28
  - BASIC_AUTH_USER=
29
  - BASIC_AUTH_PASS=
 
30
  - LOG_LEVEL=DEBUG
 
 
 
31
  - WORKERS=1
32
  - THREADS=8
 
33
  - MODEL_DIR=/data/models
34
+ - HF_CHECKPOINT_DIR=/data/checkpoints
 
 
 
 
 
 
 
35
  ports:
36
  - "9090:9090"
37
  volumes:
model.py CHANGED
@@ -15,32 +15,26 @@ class T5Model(LabelStudioMLBase):
15
  def __init__(self, **kwargs):
16
  super(T5Model, self).__init__(**kwargs)
17
 
18
- # Get configuration from kwargs (loaded from config.json if it exists)
19
- config = kwargs.get('config', {})
20
- model_config = config.get('model', {})
21
- lora_config = config.get('lora', {})
22
- training_config = config.get('training', {})
23
-
24
- # Model settings
25
- self.model_name = model_config.get('name', "google/flan-t5-base")
26
- self.max_length = model_config.get('max_length', 512)
27
- self.generation_max_length = model_config.get('generation_max_length', 128)
28
- self.num_return_sequences = model_config.get('num_return_sequences', 1)
29
 
30
  # LoRA settings
31
  self.lora_config = {
32
- "r": lora_config.get('r', 8),
33
- "alpha": lora_config.get('alpha', 32),
34
- "dropout": lora_config.get('dropout', 0.1),
35
- "target_modules": lora_config.get('target_modules', ["q", "v"])
36
  }
37
 
38
  # Training settings
39
  self.training_config = {
40
- "learning_rate": training_config.get('learning_rate', 1e-4),
41
- "batch_size": training_config.get('batch_size', 1),
42
- "max_steps": training_config.get('max_steps', 100),
43
- "save_steps": training_config.get('save_steps', 50)
44
  }
45
 
46
  # Model components
 
15
  def __init__(self, **kwargs):
16
  super(T5Model, self).__init__(**kwargs)
17
 
18
+ # Model settings from environment variables
19
+ self.model_name = os.getenv('MODEL_NAME', 'google/flan-t5-base')
20
+ self.max_length = int(os.getenv('MAX_LENGTH', '512'))
21
+ self.generation_max_length = int(os.getenv('GENERATION_MAX_LENGTH', '128'))
22
+ self.num_return_sequences = int(os.getenv('NUM_RETURN_SEQUENCES', '1'))
 
 
 
 
 
 
23
 
24
  # LoRA settings
25
  self.lora_config = {
26
+ "r": int(os.getenv('LORA_R', '8')),
27
+ "alpha": int(os.getenv('LORA_ALPHA', '32')),
28
+ "dropout": float(os.getenv('LORA_DROPOUT', '0.1')),
29
+ "target_modules": os.getenv('LORA_TARGET_MODULES', 'q,v').split(',')
30
  }
31
 
32
  # Training settings
33
  self.training_config = {
34
+ "learning_rate": float(os.getenv('LEARNING_RATE', '1e-4')),
35
+ "batch_size": int(os.getenv('BATCH_SIZE', '1')),
36
+ "max_steps": int(os.getenv('MAX_STEPS', '100')),
37
+ "save_steps": int(os.getenv('SAVE_STEPS', '50'))
38
  }
39
 
40
  # Model components