MIP-Tech Claude Sonnet 4.6 committed on
Commit
9308938
·
1 Parent(s): 27216ff

Fix deprecation warnings, CPU batch size, and root route

Browse files

- torch_dtype → dtype in from_pretrained and pipeline (transformers 4.49+)
- Move max_new_tokens to generate_kwargs in pipeline constructor to
silence generation_config conflict warning
- batch_size: 16 → 2 on CPU (16 parallel chunks waste RAM without being faster),
16 → 8 on CUDA
- Add GET / → redirect to /docs so HF Space health probes return 200
instead of 404

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (2) hide show
  1. api/main.py +6 -1
  2. src/inference/transcribe.py +8 -5
api/main.py CHANGED
@@ -4,7 +4,7 @@ from contextlib import asynccontextmanager
4
  from fastapi import FastAPI, Request
5
  from fastapi.concurrency import run_in_threadpool
6
  from fastapi.middleware.cors import CORSMiddleware
7
- from fastapi.responses import JSONResponse
8
 
9
  from api.config import settings
10
  from api.routers.transcription import router as transcription_router
@@ -81,6 +81,11 @@ app.add_middleware(
81
  app.include_router(transcription_router)
82
 
83
 
 
 
 
 
 
84
  @app.get("/health", response_model=HealthResponse, tags=["system"])
85
  async def health(request: Request) -> HealthResponse:
86
  transcriber = getattr(request.app.state, "transcriber", None)
 
4
  from fastapi import FastAPI, Request
5
  from fastapi.concurrency import run_in_threadpool
6
  from fastapi.middleware.cors import CORSMiddleware
7
+ from fastapi.responses import JSONResponse, RedirectResponse
8
 
9
  from api.config import settings
10
  from api.routers.transcription import router as transcription_router
 
81
  app.include_router(transcription_router)
82
 
83
 
84
+ @app.get("/", include_in_schema=False)
85
+ async def root() -> RedirectResponse:
86
+ return RedirectResponse(url="/docs")
87
+
88
+
89
  @app.get("/health", response_model=HealthResponse, tags=["system"])
90
  async def health(request: Request) -> HealthResponse:
91
  transcriber = getattr(request.app.state, "transcriber", None)
src/inference/transcribe.py CHANGED
@@ -24,23 +24,26 @@ class WhisperTranscriber:
24
  try:
25
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
26
  self.processor = AutoProcessor.from_pretrained(model_path)
 
27
  self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
28
  model_path,
29
- torch_dtype=torch.float16 if "cuda" in self.device else torch.float32,
30
  low_cpu_mem_usage=True,
31
  ).to(self.device)
32
-
 
 
33
  self.pipe = pipeline(
34
  "automatic-speech-recognition",
35
  model=self.model,
36
  tokenizer=self.processor.tokenizer,
37
  feature_extractor=self.processor.feature_extractor,
38
- max_new_tokens=128,
39
  chunk_length_s=30,
40
- batch_size=16,
41
  return_timestamps=True,
42
- torch_dtype=torch.float16 if "cuda" in self.device else torch.float32,
43
  device=self.device,
 
44
  )
45
  except Exception as e:
46
  logger.error("Failed to load Whisper backend: %s", e)
 
24
  try:
25
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
26
  self.processor = AutoProcessor.from_pretrained(model_path)
27
+ dtype = torch.float16 if "cuda" in self.device else torch.float32
28
  self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
29
  model_path,
30
+ dtype=dtype,
31
  low_cpu_mem_usage=True,
32
  ).to(self.device)
33
+
34
+ # batch_size=16 is only useful on GPU; CPU benefits from 1-2 chunks at a time
35
+ batch_size = 8 if "cuda" in self.device else 2
36
  self.pipe = pipeline(
37
  "automatic-speech-recognition",
38
  model=self.model,
39
  tokenizer=self.processor.tokenizer,
40
  feature_extractor=self.processor.feature_extractor,
 
41
  chunk_length_s=30,
42
+ batch_size=batch_size,
43
  return_timestamps=True,
44
+ dtype=dtype,
45
  device=self.device,
46
+ generate_kwargs={"max_new_tokens": 128},
47
  )
48
  except Exception as e:
49
  logger.error("Failed to load Whisper backend: %s", e)