Prashant26am committed on
Commit
96319a3
·
1 Parent(s): 6fc2918

Fix: Add writable cache directory for model downloads

Browse files
Files changed (2) hide show
  1. Dockerfile +3 -0
  2. app.py +13 -6
Dockerfile CHANGED
@@ -9,6 +9,9 @@ RUN apt-get update && apt-get install -y \
9
  build-essential \
10
  && rm -rf /var/lib/apt/lists/*
11
 
 
 
 
12
  COPY requirements.txt .
13
  RUN pip install --no-cache-dir -r requirements.txt
14
 
 
9
  build-essential \
10
  && rm -rf /var/lib/apt/lists/*
11
 
12
+ # Create cache directory with proper permissions
13
+ RUN mkdir -p /code/.cache && chmod 777 /code/.cache
14
+
15
  COPY requirements.txt .
16
  RUN pip install --no-cache-dir -r requirements.txt
17
 
app.py CHANGED
@@ -5,37 +5,44 @@ import tempfile
5
  import os
6
  from transformers import pipeline
7
  import logging
 
8
 
9
  # Set up logging
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
 
 
 
 
 
 
13
  app = FastAPI(title="TranscriptoCast AI (Demo)")
14
 
15
  # Load models once at startup
16
  try:
17
  logger.info("Loading Whisper model...")
18
- whisper_model = whisper.load_model("base")
19
  logger.info("Whisper model loaded successfully")
20
  except Exception as e:
21
  logger.error(f"Error loading Whisper model: {str(e)}")
22
- raise HTTPException(status_code=500, detail="Failed to load Whisper model")
23
 
24
  try:
25
  logger.info("Loading summarization model...")
26
- summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
27
  logger.info("Summarization model loaded successfully")
28
  except Exception as e:
29
  logger.error(f"Error loading summarization model: {str(e)}")
30
- raise HTTPException(status_code=500, detail="Failed to load summarization model")
31
 
32
  try:
33
  logger.info("Loading translation model...")
34
- translator = pipeline("translation", model="facebook/mbart-large-50-many-to-many-mmt")
35
  logger.info("Translation model loaded successfully")
36
  except Exception as e:
37
  logger.error(f"Error loading translation model: {str(e)}")
38
- raise HTTPException(status_code=500, detail="Failed to load translation model")
39
 
40
  @app.post("/transcribe")
41
  async def transcribe(file: UploadFile = File(...)):
 
5
  import os
6
  from transformers import pipeline
7
  import logging
8
+ from pathlib import Path
9
 
10
  # Set up logging
11
  logging.basicConfig(level=logging.INFO)
12
  logger = logging.getLogger(__name__)
13
 
14
+ # Create cache directory in the workspace
15
+ CACHE_DIR = Path("/code/.cache")
16
+ CACHE_DIR.mkdir(exist_ok=True)
17
+ os.environ["TRANSFORMERS_CACHE"] = str(CACHE_DIR)
18
+ os.environ["HF_HOME"] = str(CACHE_DIR)
19
+
20
  app = FastAPI(title="TranscriptoCast AI (Demo)")
21
 
22
  # Load models once at startup
23
  try:
24
  logger.info("Loading Whisper model...")
25
+ whisper_model = whisper.load_model("base", download_root=str(CACHE_DIR))
26
  logger.info("Whisper model loaded successfully")
27
  except Exception as e:
28
  logger.error(f"Error loading Whisper model: {str(e)}")
29
+ raise HTTPException(status_code=500, detail=f"Failed to load Whisper model: {str(e)}")
30
 
31
  try:
32
  logger.info("Loading summarization model...")
33
+ summarizer = pipeline("summarization", model="facebook/bart-large-cnn", cache_dir=str(CACHE_DIR))
34
  logger.info("Summarization model loaded successfully")
35
  except Exception as e:
36
  logger.error(f"Error loading summarization model: {str(e)}")
37
+ raise HTTPException(status_code=500, detail=f"Failed to load summarization model: {str(e)}")
38
 
39
  try:
40
  logger.info("Loading translation model...")
41
+ translator = pipeline("translation", model="facebook/mbart-large-50-many-to-many-mmt", cache_dir=str(CACHE_DIR))
42
  logger.info("Translation model loaded successfully")
43
  except Exception as e:
44
  logger.error(f"Error loading translation model: {str(e)}")
45
+ raise HTTPException(status_code=500, detail=f"Failed to load translation model: {str(e)}")
46
 
47
  @app.post("/transcribe")
48
  async def transcribe(file: UploadFile = File(...)):