# ug-asr-api / app.py
# Author: Piyazon — commit 3cbcb62 ("changed domain asr")
import os
import torch
import torchaudio
import torchcodec
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import Wav2Vec2BertProcessor, AutoModelForCTC
from pydub import AudioSegment
import tempfile
import io
from deepmultilingualpunctuation import PunctuationModel
# FastAPI application serving the Uyghur ASR endpoints.
app = FastAPI(title="Uyghur Speech to Text API")

# Allow specific domains or all (*) for testing.
# NOTE(review): browsers reject a wildcard origin when allow_credentials=True;
# list explicit origins here before relying on cookies/credentials in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# @app.get("/")
# def greet_json():
# return {"Hello": "World!"}
# @app.get("/")
# def greet_json():
# return {
# "URL: ": """<a href="https://transcriber.piyazon.top/">https://transcriber.piyazon.top/</a>"""
# }
@app.get("/", response_class=HTMLResponse)
def greet_html():
    """Landing page: links to the public frontends backed by this API."""
    landing_page = """
<html>
<body>
<h1>
URL1:
<a href="https://asr.piyazon.top">https://asr.piyazon.top</a>
</h1>
<h1>
URL2:
<a href="https://transcriber.piyazon.top">https://transcriber.piyazon.top</a>
</h1>
</body>
</html>
"""
    return landing_page
# Available Wav2Vec2 models (Hugging Face repo IDs).
# The first entry is the default `model_id` for the /transcribe endpoint.
MODEL_OPTIONS = [
    "piyazon/ASR-cv-corpus-ug-11",
    "piyazon/ASR-cv-corpus-ug-10",
    "piyazon/ASR-cv-corpus-ug-9",
    "piyazon/ASR-cv-corpus-ug-8",
    "piyazon/ASR-cv-corpus-ug-7",
]

# Global variables for processor and model — populated lazily by load_model()
# on the first /transcribe request (or when the requested model changes).
processor = None  # Wav2Vec2BertProcessor for the currently loaded model
model = None  # AutoModelForCTC, moved to CUDA when available
current_model_id = None  # repo ID currently held in `model`/`processor`
def load_model(model_id: str, hf_token: str):
    """Load the selected Wav2Vec2 model and processor into the module globals.

    Args:
        model_id: Hugging Face model repository ID to load.
        hf_token: Hugging Face authentication token (for gated/private repos).

    Returns:
        True on success.

    Raises:
        HTTPException: 500 when the processor or model cannot be loaded.
    """
    global processor, model, current_model_id
    device = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        print(f"Loading model: {model_id}")
        # Load into locals first so a failure here cannot leave the globals in
        # a mixed state (e.g. a new processor paired with the previous model).
        new_processor = Wav2Vec2BertProcessor.from_pretrained(model_id, token=hf_token)
        new_model = AutoModelForCTC.from_pretrained(model_id, token=hf_token).to(device)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error loading model: {str(e)}") from e
    # Commit all three globals together, only after both loads succeeded.
    processor = new_processor
    model = new_model
    current_model_id = model_id
    print(f"Model loaded on {device}")
    return True
def transcribe_speech(audio_bytes: bytes, model_id: str, hf_token: str) -> dict:
    """
    Transcribe audio bytes using the selected Wav2Vec2 model.

    Args:
        audio_bytes: Bytes of the audio file (any container pydub/ffmpeg can read)
        model_id: Selected Wav2Vec2 model ID
        hf_token: Hugging Face authentication token

    Returns:
        dict with 'transcription' (str) and 'word_details' (list of dicts with
        keys: word, start_time, end_time, duration, confidence)

    Raises:
        HTTPException: 500 if decoding, preprocessing, or inference fails
        (also propagated from load_model when loading fails).
    """
    global processor, model, current_model_id
    # Load model if not already loaded or if model selection changed
    if processor is None or model is None or current_model_id != model_id:
        load_model(model_id, hf_token)
    try:
        # Save audio bytes to a temporary file so pydub/ffmpeg can read them.
        # NOTE(review): the suffix is always ".webm" regardless of the actual
        # upload format; pydub sniffs the container so this mostly works, but
        # passing an explicit format would be safer — confirm with callers.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".webm") as temp_file:
            temp_file.write(audio_bytes)
            temp_file_path = temp_file.name
        # Convert to WAV (in memory) using pydub
        try:
            audio = AudioSegment.from_file(temp_file_path)
            wav_io = io.BytesIO()
            audio.export(wav_io, format="wav")
            wav_io.seek(0)
        finally:
            os.unlink(temp_file_path)  # Clean up temporary file
        # Load audio from WAV bytes
        # waveform, sample_rate = torchaudio.load(wav_io)
        # Create an audio decoder instance
        decoder = torchcodec.decoders.AudioDecoder(wav_io)
        # Get all the audio samples using the correct method
        audio_samples = decoder.get_all_samples()
        # Get the waveform and sample rate from the AudioSamples object
        waveform = audio_samples.data
        sample_rate = audio_samples.sample_rate
        print("Loaded audio shape:", waveform.shape, "sample rate:", sample_rate)
        # Resample to 16kHz (required for Wav2Vec2)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            waveform = resampler(waveform)
            sample_rate = 16000
        # Ensure waveform is mono (single channel); average channels if stereo
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)  # Convert to mono
        print("Processed waveform shape:", waveform.shape)
        # Convert waveform to input features expected by the CTC model
        processed = processor(waveform.squeeze(0), sampling_rate=16000, return_tensors="pt", padding=True)
        input_features = processed["input_features"]
        print("Input features shape:", input_features.shape)
        # Move to the same device as the model (CPU or CUDA)
        input_features = input_features.to(model.device)
        # Perform inference without tracking gradients
        with torch.no_grad():
            logits = model(input_features).logits
        # Greedy decode: most likely token per frame (batch index 0)
        pred_ids = torch.argmax(logits, dim=-1)[0]
        # Per-frame probabilities, used below for word-level confidence scores
        log_probs = torch.log_softmax(logits, dim=-1)
        probs = torch.exp(log_probs)
        # Decode token IDs to text, keeping per-word frame offsets
        word_outputs = processor.decode(pred_ids, output_word_offsets=True)
        transcription = word_outputs.text.strip()
        print(transcription)
        # Model stride: 320 samples at 16kHz = 20ms per frame
        # (standard for Wav2Vec2; adjust if the model's stride differs)
        stride_samples = 320
        frame_duration = stride_samples / sample_rate  # 0.02 seconds per frame
        # Extract word-level details from word_offsets (convert np.int64 to int for JSON)
        word_details = []
        for word_info in word_outputs.word_offsets:
            word = word_info['word']
            start_frame = int(word_info['start_offset'])  # Convert np.int64 to int
            end_frame = int(word_info['end_offset'])  # Convert np.int64 to int
            # Convert frame indices to seconds
            start_time = round(start_frame * frame_duration, 2)
            end_time = round(end_frame * frame_duration, 2)
            duration = round(end_time - start_time, 2)
            # Confidence: average probability of the argmax token over the
            # frames spanned by this word (guarding against out-of-range frames)
            word_frame_indices = range(start_frame, end_frame)
            word_probs = [
                probs[0, frame_idx, pred_ids[frame_idx]].item()
                for frame_idx in word_frame_indices if frame_idx < len(pred_ids)
            ]
            confidence = round(sum(word_probs) / len(word_probs), 3) if word_probs else 0.0
            word_details.append({
                'word': word,
                'start_time': start_time,
                'end_time': end_time,
                'duration': duration,
                'confidence': confidence
            })
        # Explicitly clean up tensors to free memory between requests
        del waveform, audio_samples, input_features, logits, pred_ids, log_probs, probs
        torch.cuda.empty_cache()  # Clear GPU memory cache if using GPU
        return {
            "transcription": transcription,
            "word_details": word_details
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing audio: {str(e)}")
# # Decode predictions
# # transcription = processor.decode(pred_ids)
# # Explicitly clean up tensors to free memory
# del waveform, audio_samples, input_features, logits, pred_ids
# torch.cuda.empty_cache() # Clear GPU memory cache if using GPU
# return transcription.strip()
# except Exception as e:
# raise HTTPException(status_code=500, detail=f"Error processing audio: {str(e)}")
def punctuate_uyghur(transcription: str) -> str:
    """
    Add punctuation to Uyghur transcription text using a multilingual model.

    Args:
        transcription (str): Unpunctuated Uyghur text (Arabic script).

    Returns:
        str: Punctuated text with Uyghur-specific punctuation marks.
    """
    # Cache the punctuation model on the function object: constructing
    # PunctuationModel loads model weights, which is far too expensive to
    # repeat on every call. (Also avoids shadowing the module-level `model`
    # global used by the ASR pipeline, as the old local name did.)
    punct_model = getattr(punctuate_uyghur, "_model", None)
    if punct_model is None:
        punct_model = PunctuationModel()
        punctuate_uyghur._model = punct_model
    # Restore punctuation using the model
    punctuated = punct_model.restore_punctuation(transcription.strip())
    # Post-process to replace Latin punctuation with Uyghur-specific marks
    punctuated = punctuated.replace(",", "،").replace("?", "؟")
    return punctuated.strip()
@app.post("/transcribe", response_model=dict)
async def transcribe(
    audio: UploadFile = File(..., description="Audio file (MP3, WAV, etc.)"),
    model_id: str = Form(MODEL_OPTIONS[0], description="Wav2Vec2 model ID"),
    hf_token: str = Form(..., description="Hugging Face authentication token")
):
    """
    Transcribe Uyghur speech from an audio file.

    - **audio**: The audio file to transcribe.
    - **model_id**: The Hugging Face model ID (defaults to first option).
    - **hf_token**: Hugging Face authentication token for accessing models.

    Returns: JSON with 'transcription' (text) and 'word_details'
    (per-word timing and confidence) fields.

    NOTE(review): accepting hf_token from the client means every caller sends
    a credential in the request body; consider reading it from a server-side
    environment variable instead.
    """
    # Read the whole upload into memory (transcribe_speech expects raw bytes)
    audio_bytes = await audio.read()
    if len(audio_bytes) == 0:
        raise HTTPException(status_code=400, detail="Empty audio file")
    # transcribe_speech returns {"transcription": ..., "word_details": [...]}
    result = transcribe_speech(audio_bytes, model_id, hf_token)
    return JSONResponse(content=result)
# transcription = transcribe_speech(audio_bytes, model_id, hf_token)
# return JSONResponse(content={"transcription": transcription})
if __name__ == "__main__":
    # Run the development server when this file is executed directly.
    import uvicorn

    host, port = "0.0.0.0", 7860
    uvicorn.run(app, host=host, port=port)