Somalitts commited on
Commit
f9ea638
·
verified ·
1 Parent(s): 9e59401

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -74
app.py CHANGED
@@ -1,93 +1,47 @@
1
  import os
2
- import io
3
- import torch
4
- import torchaudio
5
  from fastapi import FastAPI, UploadFile, File
6
  from fastapi.middleware.cors import CORSMiddleware
 
 
7
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
 
8
 
9
- # --- Dejinta App-ka FastAPI ---
10
  app = FastAPI()
11
 
12
- # U oggolow dhammaan isku xidhka (CORS) si aad uga isticmaasho meelo kale sida Flutter
13
  app.add_middleware(
14
  CORSMiddleware,
15
- allow_origins=["*"], # Waxa aad ku beddeli kartaa domain-kaaga gaarka ah mustaqbalka
16
  allow_methods=["*"],
17
  allow_headers=["*"],
18
  )
19
 
20
- # --- Soo Dejinta Moodeelka (Model Loading) ---
21
- # Kani hadda waxa uu isticmaali doonaa jidka keydka (cache path) ee lagu dejiyay Dockerfile-ka
22
- # oo ah /app/hf-cache, kaas oo leh ruqsadaha saxda ah.
23
- # Faylkan waxa la isticmaalayaa oo kaliya inta lagu jiro dhismaha Docker
24
- # si loo soo dejiyo moodeelka loogana fogaado khaladaadka ruqsadaha ee runtime-ka
25
-
26
- from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
27
- import os
28
-
29
- MODEL_ID = "Mustafaa4a/ASR-Somali"
30
-
31
- print(f"Waxaa la bilaabayaa soo dejinta moodeelka: {MODEL_ID}")
32
- print(f"Lagu keydin doonaa galka: {os.environ.get('HF_HOME')}")
33
-
34
- # Labadan sadar ayaa kicin doona soo dejinta
35
- processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
36
- model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
37
-
38
- print("Soo dejinta moodeelka waa la dhammeystiray.")
39
- # --- API Endpoints ---
40
 
41
  @app.get("/")
42
  async def root():
43
- """
44
- Endpoint-ka asaasiga ah ee lagu hubinayo in API-gu shaqaynayo.
45
- """
46
- return {"message": "Somali Speech-to-Text API wuu shaqaynayaa."}
47
 
48
  @app.post("/transcribe")
49
  async def transcribe(file: UploadFile = File(...)):
50
- """
51
- Endpoint-ka qaabilaya faylka codka ah oo u beddelaya qoraal.
52
- """
53
- if not model or not processor:
54
- return {"error": "Moodeelka lama soo rarin, fadlan eeg log-yada server-ka si aad u ogaato khaladaadka."}
55
-
56
- try:
57
- # 1. Akhrinta codka la soo galiyay
58
- audio_bytes = await file.read()
59
- audio_stream = io.BytesIO(audio_bytes)
60
-
61
- # 2. Isticmaalka torchaudio si loogu beddelo waveform
62
- waveform, sample_rate = torchaudio.load(audio_stream)
63
-
64
- # --- HAGAAYNTA TAYADA CODKA ---
65
- # Tallaabooyinkani waa muhiim si loo helo natiijooyinka ugu fiican
66
-
67
- # 2a. U beddel sample rate-ka 16kHz (oo ah ka uu moodeelku u baahan yahay)
68
- if sample_rate != 16000:
69
- resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
70
- waveform = resampler(waveform)
71
-
72
- # 2b. U beddel hal kanaal (mono) adigoo isku celcelinaya haddii uu yahay stereo
73
- if waveform.shape[0] > 1:
74
- waveform = torch.mean(waveform, dim=0, keepdim=True)
75
- # --- DHAMAADKA HAGAAYNTA CODKA ---
76
-
77
- # 3. Farsamaynta waveform-ka si uu moodeelku u fahmo
78
- inputs = processor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt")
79
-
80
- # 4. Isticmaalka moodeelka si codka loogu beddelo qoraal
81
- with torch.no_grad():
82
- logits = model(**inputs).logits
83
-
84
- # 5. Soo saarista qoraalka ugu macquulsan
85
- predicted_ids = torch.argmax(logits, dim=-1)
86
- transcription = processor.decode(predicted_ids[0])
87
-
88
- # 6. Soo celinta natiijada
89
- return {"transcription": transcription.lower()}
90
-
91
- except Exception as e:
92
- # Haddii khalad dhaco inta lagu jiro farsamaynta, soo celi fariin khalad ah
93
- return {"error": f"Khalad ayaa dhacay intii lagu jiray qoraal-u-beddelidda: {str(e)}"}
 
1
  import os
2
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf-cache" # Important for Docker
3
+
 
4
  from fastapi import FastAPI, UploadFile, File
5
  from fastapi.middleware.cors import CORSMiddleware
6
+ import torchaudio
7
+ import torch
8
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
9
+ import io
10
 
 
11
  app = FastAPI()
12
 
13
+ # Allow all origins (for Flutter)
14
  app.add_middleware(
15
  CORSMiddleware,
16
+ allow_origins=["*"],
17
  allow_methods=["*"],
18
  allow_headers=["*"],
19
  )
20
 
21
+ # Load model
22
+ processor = Wav2Vec2Processor.from_pretrained("Mustafaa4a/ASR-Somali")
23
+ model = Wav2Vec2ForCTC.from_pretrained("Mustafaa4a/ASR-Somali")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  @app.get("/")
26
  async def root():
27
+ return {"message": "Somali Speech-to-Text API is running."}
 
 
 
28
 
29
  @app.post("/transcribe")
30
  async def transcribe(file: UploadFile = File(...)):
31
+ audio_bytes = await file.read()
32
+ audio_stream = io.BytesIO(audio_bytes)
33
+
34
+ waveform, sample_rate = torchaudio.load(audio_stream)
35
+
36
+ if sample_rate != 16000:
37
+ resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
38
+ waveform = resampler(waveform)
39
+
40
+ inputs = processor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt")
41
+
42
+ with torch.no_grad():
43
+ logits = model(**inputs).logits
44
+
45
+ predicted_ids = torch.argmax(logits, dim=-1)
46
+ transcription = processor.decode(predicted_ids[0])
47
+ return {"transcription": transcription}