| import os |
| import io |
| import base64 |
| import torch |
| import numpy as np |
| import soundfile as sf |
| import librosa |
| from fastapi import FastAPI, HTTPException, UploadFile, File |
| from fastapi.responses import StreamingResponse |
| from pydantic import BaseModel |
| from transformers import ( |
| pipeline, |
| AutoTokenizer, |
| AutoModelForSeq2SeqLM, |
| Wav2Vec2ForCTC, |
| Wav2Vec2Processor, |
| SpeechT5Processor, |
| SpeechT5ForTextToSpeech, |
| SpeechT5HifiGan |
| ) |
| from peft import PeftModel |
| from datasets import load_dataset |
|
|
| CACHE_DIR = "/app/cache" |
| os.makedirs(CACHE_DIR, exist_ok=True) |
|
|
| |
| |
| os.environ["HF_HOME"] = CACHE_DIR |
| os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR |
| os.environ["HF_DATASETS_CACHE"] = CACHE_DIR |
| os.environ["HUGGINGFACE_HUB_CACHE"] = CACHE_DIR |
| os.environ["XDG_CACHE_HOME"] = CACHE_DIR |
|
|
| |
| app = FastAPI() |
|
|
| |
|
|
# Translation direction -> Hugging Face repo of the LoRA adapter for that
# direction. The keys double as the model names exposed by the API.
lora_models = {
    "tagalog-to-chabacano": "carlynamazed24/marian-mt-tl-cb-lora-finetuned",
    "chabacano-to-tagalog": "carlynamazed24/marian-mt-cb-tl-lora-finetuned",
    "english-to-chabacano": "carlynamazed24/marian-mt-en-cb-lora-finetuned",
    "chabacano-to-english": "carlynamazed24/marian-mt-cb-en-lora-finetuned",
}

# Translation direction -> MarianMT base checkpoint the adapter is applied to
# at startup. Spanish checkpoints stand in for Chabacano — presumably because
# it is a Spanish-based creole; confirm against the adapter training setup.
base_models_map = {
    "tagalog-to-chabacano": "Helsinki-NLP/opus-mt-tl-es",
    "chabacano-to-tagalog": "Helsinki-NLP/opus-mt-es-tl",
    "english-to-chabacano": "Helsinki-NLP/opus-mt-en-es",
    "chabacano-to-english": "Helsinki-NLP/opus-mt-es-en",
}
|
|
| |
| stt_model_id = "carlynamazed24/wav2vec2-chabacano-stt-v2" |
| tts_model_id = "carlynamazed24/speecht5-chabacano-tts-v2" |
|
|
| |
| translation_pipelines = {} |
|
|
| |
| stt_processor = None |
| stt_model = None |
| tts_processor = None |
| tts_model = None |
| tts_vocoder = None |
| speaker_embeddings = None |
|
|
| @app.on_event("startup") |
| async def load_models(): |
| """ Load translation models at startup (Base + LoRA). """ |
| global translation_pipelines |
| global stt_processor, stt_model |
| global tts_processor, tts_model, tts_vocoder, speaker_embeddings |
| |
| |
| for key, lora_id in lora_models.items(): |
| try: |
| print(f"🔄 Loading system: {key}") |
| |
| base_model_id = base_models_map.get(key) |
| if not base_model_id: |
| print(f"⚠️ No base model defined for {key}, skipping.") |
| continue |
|
|
| print(f" 1. Loading Base Model: {base_model_id}") |
| |
| base_model = AutoModelForSeq2SeqLM.from_pretrained( |
| base_model_id, |
| cache_dir=CACHE_DIR |
| ) |
| |
| print(f" 2. Loading LoRA Adapter: {lora_id}") |
| peft_model = PeftModel.from_pretrained( |
| base_model, |
| lora_id, |
| cache_dir=CACHE_DIR |
| ) |
| |
| model = peft_model.merge_and_unload() |
| |
| print(f" 3. Loading Tokenizer...") |
| try: |
| tokenizer = AutoTokenizer.from_pretrained(lora_id, cache_dir=CACHE_DIR) |
| except: |
| print(" (Using base tokenizer fallback)") |
| tokenizer = AutoTokenizer.from_pretrained(base_model_id, cache_dir=CACHE_DIR) |
|
|
| translation_pipelines[key] = pipeline("translation", model=model, tokenizer=tokenizer) |
| print(f"✅ Model {key} loaded and merged successfully!") |
| |
| except Exception as e: |
| print(f"❌ Error loading model {key}: {e}") |
| |
| |
| try: |
| print(f"🔄 Loading STT Model: {stt_model_id}") |
| |
| |
| _stt_processor = Wav2Vec2Processor.from_pretrained( |
| stt_model_id, |
| subfolder="best_model", |
| cache_dir=CACHE_DIR |
| ) |
| _stt_model = Wav2Vec2ForCTC.from_pretrained( |
| stt_model_id, |
| subfolder="best_model", |
| cache_dir=CACHE_DIR |
| ) |
| |
| stt_processor = _stt_processor |
| stt_model = _stt_model |
| print(f"✅ STT Model loaded successfully!") |
| except Exception as e: |
| import traceback |
| print(f"❌ Error loading STT model: {e}") |
| traceback.print_exc() |
| |
| |
| try: |
| print(f"🔄 Loading TTS Model: {tts_model_id}") |
| _tts_processor = SpeechT5Processor.from_pretrained(tts_model_id, cache_dir=CACHE_DIR) |
| _tts_model = SpeechT5ForTextToSpeech.from_pretrained(tts_model_id, cache_dir=CACHE_DIR) |
| _tts_vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan", cache_dir=CACHE_DIR) |
| |
| tts_processor = _tts_processor |
| tts_model = _tts_model |
| tts_vocoder = _tts_vocoder |
| |
| |
| try: |
| embeddings_dataset = load_dataset( |
| "Matthijs/cmu-arctic-xvectors", |
| split="validation", |
| cache_dir=CACHE_DIR, |
| trust_remote_code=True |
| ) |
| speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0) |
| except Exception as emb_error: |
| print(f"⚠️ Could not load speaker embeddings from dataset: {emb_error}") |
| print(" Using default speaker embeddings...") |
| |
| speaker_embeddings = torch.zeros(1, 512) |
| |
| print(f"✅ TTS Model loaded successfully!") |
| except Exception as e: |
| print(f"❌ Error loading TTS model: {e}") |
|
|
class TranslationRequest(BaseModel):
    """Request body for POST /translate/."""
    # Source text to translate.
    text: str
    # One of the keys of `lora_models`, e.g. "tagalog-to-chabacano".
    model: str
|
|
| @app.post("/translate/") |
| def translate(request: TranslationRequest): |
| """ Translate text using the selected model. """ |
| if request.model not in translation_pipelines: |
| raise HTTPException(status_code=400, detail=f"Model '{request.model}' not found or not loaded. Available models: {list(translation_pipelines.keys())}") |
| |
| try: |
| translator = translation_pipelines[request.model] |
| |
| result = translator(request.text)[0]["translation_text"] |
| return {"err": None, "translation": result} |
| except Exception as e: |
| raise HTTPException(status_code=500, detail=f"Translation error: {str(e)}") |
|
|
| @app.get("/status/") |
| def status(): |
| """ Check API status and loaded models. """ |
| return { |
| "status": "online", |
| "available_models": list(translation_pipelines.keys()), |
| "model_count": len(translation_pipelines) |
| } |
|
|
| @app.get("/models/") |
| def list_models(): |
| """ List all available models and their loading status. """ |
| loaded_models = set(translation_pipelines.keys()) |
| all_models = set(lora_models.keys()) |
| |
| return { |
| "loaded_models": list(loaded_models), |
| "pending_models": list(all_models - loaded_models), |
| "total_available": len(all_models) |
| } |
|
|
|
|
| |
| class STTRequest(BaseModel): |
| audio_base64: str |
|
|
| @app.post("/stt/") |
| async def speech_to_text(request: STTRequest): |
| """ Convert speech to text using Chabacano STT model. """ |
| if stt_model is None or stt_processor is None: |
| raise HTTPException(status_code=503, detail="STT model not loaded") |
| |
| try: |
| |
| audio_bytes = base64.b64decode(request.audio_base64) |
| |
| |
| audio_buffer = io.BytesIO(audio_bytes) |
| speech_array, sampling_rate = sf.read(audio_buffer) |
| |
| |
| if sampling_rate != 16000: |
| speech_array = librosa.resample(speech_array, orig_sr=sampling_rate, target_sr=16000) |
| sampling_rate = 16000 |
| |
| |
| input_values = stt_processor( |
| speech_array, |
| sampling_rate=sampling_rate, |
| return_tensors="pt" |
| ).input_values |
| |
| |
| with torch.no_grad(): |
| logits = stt_model(input_values).logits |
| |
| |
| predicted_ids = torch.argmax(logits, dim=-1) |
| transcription = stt_processor.batch_decode(predicted_ids)[0] |
| |
| return {"err": None, "transcription": transcription} |
| |
| except Exception as e: |
| raise HTTPException(status_code=500, detail=f"STT error: {str(e)}") |
|
|
|
|
| @app.post("/stt/upload/") |
| async def speech_to_text_upload(file: UploadFile = File(...)): |
| """ Convert speech to text from uploaded audio file. """ |
| if stt_model is None or stt_processor is None: |
| raise HTTPException(status_code=503, detail="STT model not loaded") |
| |
| try: |
| |
| audio_bytes = await file.read() |
| audio_buffer = io.BytesIO(audio_bytes) |
| speech_array, sampling_rate = sf.read(audio_buffer) |
| |
| |
| if sampling_rate != 16000: |
| speech_array = librosa.resample(speech_array, orig_sr=sampling_rate, target_sr=16000) |
| sampling_rate = 16000 |
| |
| |
| input_values = stt_processor( |
| speech_array, |
| sampling_rate=sampling_rate, |
| return_tensors="pt" |
| ).input_values |
| |
| |
| with torch.no_grad(): |
| logits = stt_model(input_values).logits |
| |
| |
| predicted_ids = torch.argmax(logits, dim=-1) |
| transcription = stt_processor.batch_decode(predicted_ids)[0] |
| |
| return {"err": None, "transcription": transcription} |
| |
| except Exception as e: |
| raise HTTPException(status_code=500, detail=f"STT error: {str(e)}") |
|
|
|
|
| |
| class TTSRequest(BaseModel): |
| text: str |
|
|
| @app.post("/tts/") |
| async def text_to_speech(request: TTSRequest): |
| """ Convert text to speech using Chabacano TTS model. """ |
| if tts_model is None or tts_processor is None or tts_vocoder is None: |
| raise HTTPException(status_code=503, detail="TTS model not loaded") |
| |
| try: |
| |
| inputs = tts_processor(text=request.text, return_tensors="pt") |
| |
| |
| with torch.no_grad(): |
| speech = tts_model.generate_speech( |
| inputs["input_ids"], |
| speaker_embeddings, |
| vocoder=tts_vocoder |
| ) |
| |
| |
| speech_numpy = speech.numpy() |
| |
| |
| audio_buffer = io.BytesIO() |
| sf.write(audio_buffer, speech_numpy, samplerate=16000, format='WAV') |
| audio_buffer.seek(0) |
| |
| |
| return StreamingResponse( |
| audio_buffer, |
| media_type="audio/wav", |
| headers={"Content-Disposition": "attachment; filename=speech.wav"} |
| ) |
| |
| except Exception as e: |
| raise HTTPException(status_code=500, detail=f"TTS error: {str(e)}") |
|
|
|
|
| @app.post("/tts/base64/") |
| async def text_to_speech_base64(request: TTSRequest): |
| """ Convert text to speech and return as base64 encoded audio. """ |
| if tts_model is None or tts_processor is None or tts_vocoder is None: |
| raise HTTPException(status_code=503, detail="TTS model not loaded") |
| |
| try: |
| |
| inputs = tts_processor(text=request.text, return_tensors="pt") |
| |
| |
| with torch.no_grad(): |
| speech = tts_model.generate_speech( |
| inputs["input_ids"], |
| speaker_embeddings, |
| vocoder=tts_vocoder |
| ) |
| |
| |
| speech_numpy = speech.numpy() |
| |
| |
| audio_buffer = io.BytesIO() |
| sf.write(audio_buffer, speech_numpy, samplerate=16000, format='WAV') |
| audio_buffer.seek(0) |
| |
| |
| audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8') |
| |
| return {"err": None, "audio_base64": audio_base64} |
| |
| except Exception as e: |
| raise HTTPException(status_code=500, detail=f"TTS error: {str(e)}") |