Nada committed on
Commit
5525e9f
·
1 Parent(s): 872c7e8
Files changed (2) hide show
  1. chatbot.py +41 -21
  2. requirements.txt +1 -0
chatbot.py CHANGED
@@ -6,6 +6,7 @@ import re
6
  from typing import List, Dict, Any, Optional, Union
7
  from datetime import datetime
8
  from pydantic import BaseModel, Field
 
9
 
10
  # Model imports
11
  from transformers import (
@@ -34,7 +35,7 @@ from conversation_flow import FlowManager
34
  logging.basicConfig(
35
  level=logging.INFO,
36
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
37
- handlers=[logging.StreamHandler()] # Only use StreamHandler for Hugging Face Spaces
38
  )
39
  logger = logging.getLogger(__name__)
40
 
@@ -42,12 +43,36 @@ logger = logging.getLogger(__name__)
42
  import warnings
43
  warnings.filterwarnings('ignore', category=UserWarning)
44
 
45
- # Set environment variables
46
- os.environ.update({
47
- 'TRANSFORMERS_VERBOSITY': 'error',
48
- 'TOKENIZERS_PARALLELISM': 'false',
49
- 'BITSANDBYTES_NOWELCOME': '1'
50
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  # Define base directory and paths
53
  BASE_DIR = os.path.abspath(os.path.dirname(__file__))
@@ -289,19 +314,17 @@ Response:"""
289
  model="SamLowe/roberta-base-go_emotions",
290
  top_k=None,
291
  device_map="auto" if self.device == "cuda" else None,
292
- local_files_only=False, # Force download from Hugging Face Hub
293
- use_auth_token=False # No authentication needed for public models
294
  )
295
  except Exception as e:
296
  logger.error(f"Error loading emotion model: {e}")
297
- # Fallback to a simpler model
298
  return pipeline(
299
  "text-classification",
300
  model="j-hartmann/emotion-english-distilroberta-base",
301
  return_all_scores=True,
302
  device_map="auto" if self.device == "cuda" else None,
303
- local_files_only=False,
304
- use_auth_token=False
305
  )
306
 
307
  def _initialize_llm(self, model_name: str, use_4bit: bool):
@@ -317,33 +340,30 @@ Response:"""
317
  else:
318
  quantization_config = None
319
 
320
- # Load base model directly from Hugging Face Hub
321
  logger.info(f"Loading base model: {model_name}")
322
  base_model = AutoModelForCausalLM.from_pretrained(
323
  model_name,
324
  quantization_config=quantization_config,
325
  device_map="auto" if self.device == "cuda" else None,
326
  trust_remote_code=True,
327
- local_files_only=False,
328
- use_auth_token=False
329
  )
330
 
331
- # Load tokenizer directly from Hugging Face Hub
332
  logger.info("Loading tokenizer")
333
  tokenizer = AutoTokenizer.from_pretrained(
334
  model_name,
335
- local_files_only=False,
336
- use_auth_token=False
337
  )
338
  tokenizer.pad_token = tokenizer.eos_token
339
 
340
- # Load PEFT model directly from Hugging Face Hub
341
  logger.info(f"Loading PEFT model from {self.peft_model_path}")
342
  model = PeftModel.from_pretrained(
343
  base_model,
344
  self.peft_model_path,
345
- local_files_only=False,
346
- use_auth_token=False
347
  )
348
  logger.info("Successfully loaded PEFT model")
349
 
 
6
  from typing import List, Dict, Any, Optional, Union
7
  from datetime import datetime
8
  from pydantic import BaseModel, Field
9
+ import tempfile
10
 
11
  # Model imports
12
  from transformers import (
 
35
  logging.basicConfig(
36
  level=logging.INFO,
37
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
38
+ handlers=[logging.StreamHandler()]
39
  )
40
  logger = logging.getLogger(__name__)
41
 
 
43
  import warnings
44
  warnings.filterwarnings('ignore', category=UserWarning)
45
 
46
# Set up cache directories
def setup_cache_dirs():
    """Configure Hugging Face cache directories and return the cache path.

    On Hugging Face Spaces (detected via the ``SPACE_ID`` environment
    variable) the default home directory is not writable, so all caches are
    redirected to ``/tmp``.  For local development the library's standard
    ``~/.cache/huggingface`` location is used.

    Returns:
        str: Absolute path of the cache directory (created if missing).
    """
    # Settings shared by both environments: silence tokenizer-fork warnings,
    # reduce transformers log noise, hide the bitsandbytes banner.
    env_updates = {
        'TOKENIZERS_PARALLELISM': 'false',
        'TRANSFORMERS_VERBOSITY': 'error',
        'BITSANDBYTES_NOWELCOME': '1',
    }

    if os.environ.get('SPACE_ID') is not None:
        # Hugging Face Spaces: only /tmp is guaranteed writable.
        cache_dir = '/tmp/huggingface'
        # NOTE(review): TRANSFORMERS_CACHE is deprecated in recent
        # transformers releases in favor of HF_HOME; both are set here for
        # compatibility with older versions — confirm against the pinned
        # transformers version.
        env_updates.update({
            'TRANSFORMERS_CACHE': cache_dir,
            'HF_HOME': cache_dir,
        })
    else:
        # Local development: leave the default cache location untouched.
        cache_dir = os.path.expanduser('~/.cache/huggingface')

    os.environ.update(env_updates)

    # Ensure the cache directory exists before any model download starts.
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir

# Set up cache directories
CACHE_DIR = setup_cache_dirs()
76
 
77
  # Define base directory and paths
78
  BASE_DIR = os.path.abspath(os.path.dirname(__file__))
 
314
  model="SamLowe/roberta-base-go_emotions",
315
  top_k=None,
316
  device_map="auto" if self.device == "cuda" else None,
317
+ cache_dir=CACHE_DIR
 
318
  )
319
  except Exception as e:
320
  logger.error(f"Error loading emotion model: {e}")
321
+ # Fallback
322
  return pipeline(
323
  "text-classification",
324
  model="j-hartmann/emotion-english-distilroberta-base",
325
  return_all_scores=True,
326
  device_map="auto" if self.device == "cuda" else None,
327
+ cache_dir=CACHE_DIR
 
328
  )
329
 
330
  def _initialize_llm(self, model_name: str, use_4bit: bool):
 
340
  else:
341
  quantization_config = None
342
 
343
+ # Load base model
344
  logger.info(f"Loading base model: {model_name}")
345
  base_model = AutoModelForCausalLM.from_pretrained(
346
  model_name,
347
  quantization_config=quantization_config,
348
  device_map="auto" if self.device == "cuda" else None,
349
  trust_remote_code=True,
350
+ cache_dir=CACHE_DIR
 
351
  )
352
 
353
+ # Load tokenizer
354
  logger.info("Loading tokenizer")
355
  tokenizer = AutoTokenizer.from_pretrained(
356
  model_name,
357
+ cache_dir=CACHE_DIR
 
358
  )
359
  tokenizer.pad_token = tokenizer.eos_token
360
 
361
+ # Load PEFT model
362
  logger.info(f"Loading PEFT model from {self.peft_model_path}")
363
  model = PeftModel.from_pretrained(
364
  base_model,
365
  self.peft_model_path,
366
+ cache_dir=CACHE_DIR
 
367
  )
368
  logger.info("Successfully loaded PEFT model")
369
 
requirements.txt CHANGED
@@ -24,4 +24,5 @@ tokenizers>=0.21.1
24
  tiktoken>=0.9.0
25
  starlette>=0.46.1
26
  websockets>=15.0.1
 
27
  python-multipart>=0.0.6
 
24
  tiktoken>=0.9.0
25
  starlette>=0.46.1
26
  websockets>=15.0.1
27
+ # NOTE: tempfile is part of the Python standard library — it must not be listed as a pip requirement (pip install would fail); the plain `import tempfile` in chatbot.py is sufficient
28
  python-multipart>=0.0.6