Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import torchaudio | |
| import torchcodec | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
| from fastapi.responses import JSONResponse, HTMLResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from transformers import Wav2Vec2BertProcessor, AutoModelForCTC | |
| from pydub import AudioSegment | |
| import tempfile | |
| import io | |
| from deepmultilingualpunctuation import PunctuationModel | |
# Application instance; the title shows up in the auto-generated OpenAPI docs.
app = FastAPI(title="Uyghur Speech to Text API")

# Allow specific domains or all (*) for testing.
# NOTE(review): the CORS spec forbids a wildcard origin for credentialed
# requests, so allow_origins=["*"] together with allow_credentials=True will
# be rejected by browsers when credentials are actually sent — restrict the
# origins list (or drop allow_credentials) before production. Confirm intent.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
| # @app.get("/") | |
| # def greet_json(): | |
| # return {"Hello": "World!"} | |
| # @app.get("/") | |
| # def greet_json(): | |
| # return { | |
| # "URL: ": """<a href="https://transcriber.piyazon.top/">https://transcriber.piyazon.top/</a>""" | |
| # } | |
# NOTE(review): the original function had no route decorator, so the app
# exposed no root route at all; the commented-out "/" handlers above and the
# HTMLResponse import at the top of the file indicate it was meant to serve
# GET "/" — confirm the intended path.
@app.get("/", response_class=HTMLResponse)
def greet_html():
    """Serve a minimal HTML landing page linking to the public frontends.

    Returns:
        str: An HTML document with links to the ASR and transcriber sites.
    """
    return """
    <html>
        <body>
            <h1>
                URL1:
                <a href="https://asr.piyazon.top">https://asr.piyazon.top</a>
            </h1>
            <h1>
                URL2:
                <a href="https://transcriber.piyazon.top">https://transcriber.piyazon.top</a>
            </h1>
        </body>
    </html>
    """
# Available Wav2Vec2 models (Hugging Face Hub IDs); clients select one per
# request via the "model_id" form field, defaulting to the first entry.
MODEL_OPTIONS = [
    "piyazon/ASR-cv-corpus-ug-11",
    "piyazon/ASR-cv-corpus-ug-10",
    "piyazon/ASR-cv-corpus-ug-9",
    "piyazon/ASR-cv-corpus-ug-8",
    "piyazon/ASR-cv-corpus-ug-7",
]

# Lazily-initialized globals for the currently loaded model; populated by
# load_model() on the first transcription request (or on a model switch).
processor = None  # Wav2Vec2BertProcessor for the active model
model = None  # AutoModelForCTC instance, moved to CPU or CUDA
current_model_id = None  # Hub ID of the model currently held in memory
def load_model(model_id: str, hf_token: str):
    """Download the requested processor/model pair and install them globally.

    Args:
        model_id: Hugging Face Hub ID of the Wav2Vec2 model to load.
        hf_token: Hugging Face authentication token for gated/private repos.

    Returns:
        True on success.

    Raises:
        HTTPException: 500 wrapping any failure during download or placement.
    """
    global processor, model, current_model_id
    try:
        print(f"Loading model: {model_id}")
        # Prefer the GPU when one is visible to torch.
        if torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
        processor = Wav2Vec2BertProcessor.from_pretrained(model_id, token=hf_token)
        model = AutoModelForCTC.from_pretrained(model_id, token=hf_token).to(device)
        # Record which checkpoint is resident so callers can skip reloads.
        current_model_id = model_id
        print(f"Model loaded on {device}")
        return True
    except Exception as err:
        raise HTTPException(status_code=500, detail=f"Error loading model: {str(err)}")
def transcribe_speech(audio_bytes: bytes, model_id: str, hf_token: str) -> dict:
    """
    Transcribe audio bytes using the selected Wav2Vec2 model.

    Args:
        audio_bytes: Raw bytes of the uploaded audio file (any format
            pydub/ffmpeg can decode).
        model_id: Selected Wav2Vec2 model ID (Hugging Face Hub).
        hf_token: Hugging Face authentication token.

    Returns:
        dict with two keys:
            "transcription": the decoded text (stripped), and
            "word_details": a list of per-word dicts with 'word',
            'start_time', 'end_time', 'duration' (seconds) and 'confidence'
            (mean per-frame max probability).

    Raises:
        HTTPException: 500 on any model-loading or audio-processing failure.
    """
    global processor, model, current_model_id
    # Load model if not already loaded or if model selection changed.
    if processor is None or model is None or current_model_id != model_id:
        load_model(model_id, hf_token)
    try:
        # Save audio bytes to a temporary file so pydub/ffmpeg can read it.
        # NOTE(review): the suffix is always ".webm" even though callers may
        # upload MP3/WAV; pydub sniffs the real format from content, so this
        # appears to work, but the suffix is misleading — confirm.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".webm") as temp_file:
            temp_file.write(audio_bytes)
            temp_file_path = temp_file.name
        # Convert to WAV in memory using pydub; unlink the temp file even on failure.
        try:
            audio = AudioSegment.from_file(temp_file_path)
            wav_io = io.BytesIO()
            audio.export(wav_io, format="wav")
            wav_io.seek(0)
        finally:
            os.unlink(temp_file_path)  # Clean up temporary file
        # Load audio from WAV bytes
        # waveform, sample_rate = torchaudio.load(wav_io)
        # Create an audio decoder instance (torchcodec replaces torchaudio.load here)
        decoder = torchcodec.decoders.AudioDecoder(wav_io)
        # Get all the audio samples using the correct method
        audio_samples = decoder.get_all_samples()
        # Get the waveform and sample rate from the AudioSamples object
        waveform = audio_samples.data
        sample_rate = audio_samples.sample_rate
        print("Loaded audio shape:", waveform.shape, "sample rate:", sample_rate)
        # Resample to 16kHz (required for Wav2Vec2)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            waveform = resampler(waveform)
            sample_rate = 16000
        # Ensure waveform is mono (single channel); assumes dim 0 is channels
        # (consistent with the channel-mean below) — TODO confirm torchcodec layout.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)  # Convert to mono
        print("Processed waveform shape:", waveform.shape)
        # Convert waveform to input features
        processed = processor(waveform.squeeze(0), sampling_rate=16000, return_tensors="pt", padding=True)
        input_features = processed["input_features"]
        print("Input features shape:", input_features.shape)
        # Move to device
        input_features = input_features.to(model.device)
        # Perform inference
        with torch.no_grad():
            logits = model(input_features).logits
        # Get predicted token IDs (batch of 1, so take the first row)
        pred_ids = torch.argmax(logits, dim=-1)[0]
        # Compute probabilities for confidence
        log_probs = torch.log_softmax(logits, dim=-1)
        probs = torch.exp(log_probs)
        # Decode with word offsets
        word_outputs = processor.decode(pred_ids, output_word_offsets=True)
        transcription = word_outputs.text.strip()
        print(transcription)
        # Model stride: 320 samples at 16kHz = 20ms per frame (standard for Wav2Vec2; adjust if needed)
        # NOTE(review): Wav2Vec2-BERT consumes mel features rather than raw
        # waveform, so the effective frame duration may differ from the 20 ms
        # raw-waveform assumption — verify timestamps against known audio.
        stride_samples = 320
        frame_duration = stride_samples / sample_rate  # 0.02 seconds
        # Extract word-level details from word_offsets (convert np.int64 to int for JSON)
        word_details = []
        for word_info in word_outputs.word_offsets:
            word = word_info['word']
            start_frame = int(word_info['start_offset'])  # Convert np.int64 to int
            end_frame = int(word_info['end_offset'])  # Convert np.int64 to int
            # Convert frames to seconds
            start_time = round(start_frame * frame_duration, 2)
            end_time = round(end_frame * frame_duration, 2)
            duration = round(end_time - start_time, 2)
            # Compute confidence: Average max prob over frames in this word
            word_frame_indices = range(start_frame, end_frame)
            word_probs = [
                probs[0, frame_idx, pred_ids[frame_idx]].item()
                for frame_idx in word_frame_indices if frame_idx < len(pred_ids)
            ]
            confidence = round(sum(word_probs) / len(word_probs), 3) if word_probs else 0.0
            word_details.append({
                'word': word,
                'start_time': start_time,
                'end_time': end_time,
                'duration': duration,
                'confidence': confidence
            })
        # Explicitly clean up tensors to free memory
        del waveform, audio_samples, input_features, logits, pred_ids, log_probs, probs
        torch.cuda.empty_cache()  # Clear GPU memory cache if using GPU
        return {
            "transcription": transcription,
            "word_details": word_details
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing audio: {str(e)}")
| # # Decode predictions | |
| # # transcription = processor.decode(pred_ids) | |
| # # Explicitly clean up tensors to free memory | |
| # del waveform, audio_samples, input_features, logits, pred_ids | |
| # torch.cuda.empty_cache() # Clear GPU memory cache if using GPU | |
| # return transcription.strip() | |
| # except Exception as e: | |
| # raise HTTPException(status_code=500, detail=f"Error processing audio: {str(e)}") | |
def _get_punctuation_model():
    """Return a process-wide cached PunctuationModel, creating it on first use."""
    # Memoize on the function object: constructing PunctuationModel loads a
    # transformer checkpoint, which is far too expensive to repeat per request.
    cached = getattr(_get_punctuation_model, "_model", None)
    if cached is None:
        cached = PunctuationModel()
        _get_punctuation_model._model = cached
    return cached


def punctuate_uyghur(transcription: str) -> str:
    """
    Add punctuation to Uyghur transcription text using a multilingual model.

    Args:
        transcription (str): Unpunctuated Uyghur text (Arabic script).

    Returns:
        str: Punctuated text with Uyghur-specific punctuation marks.
    """
    # Reuse the cached model (the original rebuilt PunctuationModel on every
    # call); also avoids shadowing the module-level `model` global.
    punct_model = _get_punctuation_model()
    # Restore punctuation using the model
    punctuated = punct_model.restore_punctuation(transcription.strip())
    # Post-process to replace Latin punctuation with Uyghur-specific marks
    punctuated = punctuated.replace(",", "،").replace("?", "؟")
    return punctuated.strip()
# NOTE(review): the original coroutine had no route decorator, so FastAPI
# never exposed this handler; POST /transcribe matches its File/Form
# signature — confirm the intended path with the frontend.
@app.post("/transcribe")
async def transcribe(
    audio: UploadFile = File(..., description="Audio file (MP3, WAV, etc.)"),
    model_id: str = Form(MODEL_OPTIONS[0], description="Wav2Vec2 model ID"),
    hf_token: str = Form(..., description="Hugging Face authentication token")
):
    """
    Transcribe Uyghur speech from an audio file.

    - **audio**: The audio file to transcribe.
    - **model_id**: The Hugging Face model ID (defaults to first option).
    - **hf_token**: Hugging Face authentication token for accessing models.

    Returns: JSON with 'transcription' (Uyghur text) and 'word_details'
    (per-word timing/confidence) fields.

    Raises:
        HTTPException: 400 for an empty upload; 500 for processing failures
        (propagated from transcribe_speech).
    """
    # Read the whole upload into memory; transcribe_speech works on raw bytes.
    audio_bytes = await audio.read()
    if len(audio_bytes) == 0:
        raise HTTPException(status_code=400, detail="Empty audio file")
    result = transcribe_speech(audio_bytes, model_id, hf_token)
    return JSONResponse(content=result)
if __name__ == "__main__":
    # Direct-execution entry point; port 7860 is the standard Hugging Face
    # Spaces port. Binding 0.0.0.0 exposes the server on all interfaces.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)