Quartz4065 committed on
Commit
ae77e3b
·
verified ·
1 Parent(s): 50738b3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import shutil
import tempfile
from typing import List, Optional

from fastapi import FastAPI, File, Form, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel
from pydantic import BaseModel
9
+
10
# Port number taken from the PORT environment variable; 7860 when it is unset.
APP_PORT = int(os.getenv("PORT", "7860"))
11
+
12
# Already-constructed models keyed by model name, reused across calls.
_models = {}


def get_model(name: str):
    """Return the WhisperModel registered under ``name``, building it on first use.

    Args:
        name: Model identifier passed straight to ``WhisperModel``.

    Returns:
        The cached model instance for ``name``. A new instance is created with
        int8 compute and one CPU thread per reported core (2 when the core
        count is unavailable) and stored in ``_models`` before being returned.
    """
    if name in _models:
        return _models[name]

    thread_count = os.cpu_count() or 2
    loaded = WhisperModel(name, compute_type="int8", cpu_threads=thread_count)
    _models[name] = loaded
    return loaded
19
+
20
class Segment(BaseModel):
    """One transcribed span of audio as returned in the API response."""

    # Segment boundaries as reported by the transcriber
    # (presumably seconds from the start of the audio -- confirm upstream).
    start: float
    end: float
    # Transcribed text of this span.
    text: str
24
+
25
class TranscribeOut(BaseModel):
    """Response body of the ``/transcribe`` endpoint."""

    # Full transcript: all segment texts joined with single spaces.
    text: str
    # Individual segments in the order they were produced.
    segments: List[Segment]
    # Audio duration reported by the transcriber; None when unavailable.
    duration_sec: Optional[float] = None
    # Whitespace-separated word count across all segments.
    words: Optional[int] = None
    # Words per minute derived from ``words`` and ``duration_sec``; None when
    # the duration is missing or not positive.
    wpm: Optional[float] = None
    # Name of the model that was requested for this transcription.
    model: str
32
+
33
app = FastAPI(title="Nuvia Free Transcriber")

# Fully open CORS policy: any origin, method and header may call the API.
# Credentials are disabled on purpose: the FastAPI CORS documentation states
# that wildcard origins/methods/headers cannot be combined with
# allow_credentials=True, and none of the endpoints in this file read cookies
# or authorization headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
39
+
40
@app.get("/health")
def health():
    """Liveness probe: always answers with ``{"ok": True}``."""
    status = {"ok": True}
    return status
43
+
44
@app.post("/transcribe", response_model=TranscribeOut)
def transcribe(file: UploadFile = File(...), model: str = Form("base.en")):
    """Transcribe an uploaded audio file and return text, segments and stats.

    The upload is copied to a temporary file because the model is handed a
    filesystem path. The temporary file is removed before the function
    returns, whether transcription succeeds or fails.

    Args:
        file: Uploaded audio file (multipart form field ``file``).
        model: Model name passed to ``get_model``; defaults to ``"base.en"``.

    Returns:
        A ``TranscribeOut`` carrying the joined transcript, the individual
        segments, the reported duration, the whitespace-separated word count,
        words per minute (``None`` when the duration is missing or not
        positive) and the requested model name.
    """
    tmp_path = None
    try:
        # delete=False: the file must outlive this handle so it can be
        # reopened by path. Creating and filling it inside the try block
        # guarantees cleanup even if the copy fails midway.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
            tmp_path = tmp.name
            # Stream in chunks instead of loading the whole upload into memory.
            shutil.copyfileobj(file.file, tmp)

        m = get_model(model)
        segments, info = m.transcribe(tmp_path, vad_filter=True)

        segs = []
        total_words = 0
        for s in segments:
            txt = s.text.strip()
            segs.append(Segment(start=float(s.start), end=float(s.end), text=txt))
            total_words += len(txt.split())

        # A missing or zero duration is reported as None rather than 0.0.
        raw_duration = getattr(info, "duration", None)
        dur = float(raw_duration) if raw_duration else None

        wpm = None
        if dur and dur > 0:
            wpm = round(total_words / (dur / 60.0), 2)

        full_text = " ".join(seg.text for seg in segs).strip()
        return TranscribeOut(
            text=full_text,
            segments=segs,
            duration_sec=dur,
            words=total_words,
            wpm=wpm,
            model=model,
        )
    finally:
        # Best-effort cleanup: a failed delete must not mask the real outcome,
        # but only filesystem errors are ignored.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass