# app.py — FastAPI service for Chabacano translation, speech-to-text, and text-to-speech.
import os
import io
import base64
import torch
import numpy as np
import soundfile as sf
import librosa
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import (
pipeline,
AutoTokenizer,
AutoModelForSeq2SeqLM,
Wav2Vec2ForCTC,
Wav2Vec2Processor,
SpeechT5Processor,
SpeechT5ForTextToSpeech,
SpeechT5HifiGan
)
from peft import PeftModel
from datasets import load_dataset
# Writable cache directory for all Hugging Face downloads (models/datasets).
CACHE_DIR = "/app/cache"
os.makedirs(CACHE_DIR, exist_ok=True)
# Point every Hugging Face cache env var at the writable container cache.
# NOTE(review): these are set AFTER `transformers`/`datasets` are imported,
# so values read at import time may be ignored; as a safeguard, cache_dir
# is also passed explicitly to every from_pretrained()/load_dataset() call.
os.environ["HF_HOME"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
os.environ["HUGGINGFACE_HUB_CACHE"] = CACHE_DIR
os.environ["XDG_CACHE_HOME"] = CACHE_DIR
# Initialize FastAPI
app = FastAPI()
# --- MODEL CONFIGURATION ---
# Translation direction -> LoRA adapter repo fine-tuned for that direction.
lora_models = {
    "tagalog-to-chabacano": "carlynamazed24/marian-mt-tl-cb-lora-finetuned",
    "chabacano-to-tagalog": "carlynamazed24/marian-mt-cb-tl-lora-finetuned",
    "english-to-chabacano": "carlynamazed24/marian-mt-en-cb-lora-finetuned",
    "chabacano-to-english": "carlynamazed24/marian-mt-cb-en-lora-finetuned",
}
# Translation direction -> Marian base model the LoRA adapter is merged into.
# Spanish (es) checkpoints stand in for Chabacano (Cb), a Spanish creole.
base_models_map = {
    "tagalog-to-chabacano": "Helsinki-NLP/opus-mt-tl-es",  # TL -> ES (Cb)
    "chabacano-to-tagalog": "Helsinki-NLP/opus-mt-es-tl",  # ES (Cb) -> TL
    "english-to-chabacano": "Helsinki-NLP/opus-mt-en-es",  # EN -> ES (Cb)
    "chabacano-to-english": "Helsinki-NLP/opus-mt-es-en",  # ES (Cb) -> EN
}
# --- STT & TTS MODELS ---
stt_model_id = "carlynamazed24/wav2vec2-chabacano-stt-v2"
tts_model_id = "carlynamazed24/speecht5-chabacano-tts-v2"
# Translation pipelines keyed by direction; populated at startup.
translation_pipelines = {}
# STT & TTS models (will be loaded at startup); None until load succeeds,
# and the endpoints return 503 while any of these are still None.
stt_processor = None
stt_model = None
tts_processor = None
tts_model = None
tts_vocoder = None
speaker_embeddings = None
@app.on_event("startup")
async def load_models():
    """Load all models at startup: translation (base + LoRA), STT, and TTS.

    Populates the module-level globals used by the endpoints. Each model
    group is loaded inside its own try/except so one failure does not
    block the others; failures are logged and the affected endpoints keep
    returning 503 until a restart succeeds.
    """
    global translation_pipelines
    global stt_processor, stt_model
    global tts_processor, tts_model, tts_vocoder, speaker_embeddings
    # --- Load Translation Models: merge each LoRA adapter into its base ---
    for key, lora_id in lora_models.items():
        try:
            print(f"🔄 Loading system: {key}")
            base_model_id = base_models_map.get(key)
            if not base_model_id:
                print(f"⚠️ No base model defined for {key}, skipping.")
                continue
            print(f" 1. Loading Base Model: {base_model_id}")
            base_model = AutoModelForSeq2SeqLM.from_pretrained(
                base_model_id,
                cache_dir=CACHE_DIR
            )
            print(f" 2. Loading LoRA Adapter: {lora_id}")
            peft_model = PeftModel.from_pretrained(
                base_model,
                lora_id,
                cache_dir=CACHE_DIR
            )
            # Merge adapter weights into the base so inference needs no PEFT wrapper.
            model = peft_model.merge_and_unload()
            print(" 3. Loading Tokenizer...")
            try:
                tokenizer = AutoTokenizer.from_pretrained(lora_id, cache_dir=CACHE_DIR)
            except Exception:
                # FIX: was a bare `except:`, which also swallowed SystemExit and
                # KeyboardInterrupt; narrowed to Exception. The adapter repo may
                # not ship tokenizer files, so fall back to the base model's.
                print(" (Using base tokenizer fallback)")
                tokenizer = AutoTokenizer.from_pretrained(base_model_id, cache_dir=CACHE_DIR)
            translation_pipelines[key] = pipeline("translation", model=model, tokenizer=tokenizer)
            print(f"✅ Model {key} loaded and merged successfully!")
        except Exception as e:
            print(f"❌ Error loading model {key}: {e}")
    # --- Load STT Model (Wav2Vec2) ---
    try:
        print(f"🔄 Loading STT Model: {stt_model_id}")
        # Load processor and model from best_model subfolder which has complete files
        _stt_processor = Wav2Vec2Processor.from_pretrained(
            stt_model_id,
            subfolder="best_model",
            cache_dir=CACHE_DIR
        )
        _stt_model = Wav2Vec2ForCTC.from_pretrained(
            stt_model_id,
            subfolder="best_model",
            cache_dir=CACHE_DIR
        )
        # Publish to globals only after both loads succeeded, so the endpoints
        # never see a processor without its model (or vice versa).
        stt_processor = _stt_processor
        stt_model = _stt_model
        print("✅ STT Model loaded successfully!")
    except Exception as e:
        import traceback
        print(f"❌ Error loading STT model: {e}")
        traceback.print_exc()
    # --- Load TTS Model (SpeechT5) ---
    try:
        print(f"🔄 Loading TTS Model: {tts_model_id}")
        _tts_processor = SpeechT5Processor.from_pretrained(tts_model_id, cache_dir=CACHE_DIR)
        _tts_model = SpeechT5ForTextToSpeech.from_pretrained(tts_model_id, cache_dir=CACHE_DIR)
        _tts_vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan", cache_dir=CACHE_DIR)
        tts_processor = _tts_processor
        tts_model = _tts_model
        tts_vocoder = _tts_vocoder
        # Load speaker embeddings from CMU Arctic dataset (trust_remote_code is
        # required by newer `datasets` versions for this script-based dataset).
        try:
            embeddings_dataset = load_dataset(
                "Matthijs/cmu-arctic-xvectors",
                split="validation",
                cache_dir=CACHE_DIR,
                trust_remote_code=True
            )
            # Index 7306 selects a fixed speaker x-vector; shape (1, 512).
            speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
        except Exception as emb_error:
            print(f"⚠️ Could not load speaker embeddings from dataset: {emb_error}")
            print(" Using default speaker embeddings...")
            # Create a default speaker embedding (512-dimensional vector)
            speaker_embeddings = torch.zeros(1, 512)
        print("✅ TTS Model loaded successfully!")
    except Exception as e:
        print(f"❌ Error loading TTS model: {e}")
class TranslationRequest(BaseModel):
    """Request body for /translate/."""
    text: str  # Source text to translate.
    model: str  # Translation direction key, e.g. "tagalog-to-chabacano".
@app.post("/translate/")
def translate(request: TranslationRequest):
    """Translate request.text using the pipeline selected by request.model."""
    translator = translation_pipelines.get(request.model)
    if translator is None:
        # Unknown or not-yet-loaded direction -> client error with the valid keys.
        raise HTTPException(status_code=400, detail=f"Model '{request.model}' not found or not loaded. Available models: {list(translation_pipelines.keys())}")
    try:
        outputs = translator(request.text)
        translated_text = outputs[0]["translation_text"]
        return {"err": None, "translation": translated_text}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Translation error: {str(exc)}")
@app.get("/status/")
def status():
    """Report API liveness plus the translation models currently loaded."""
    loaded = list(translation_pipelines.keys())
    return {
        "status": "online",
        "available_models": loaded,
        "model_count": len(loaded),
    }
@app.get("/models/")
def list_models():
    """Show which configured translation models are loaded vs. still pending."""
    configured = set(lora_models)
    ready = set(translation_pipelines)
    return {
        "loaded_models": list(ready),
        "pending_models": list(configured - ready),
        "total_available": len(configured),
    }
# --- STT ENDPOINT ---
class STTRequest(BaseModel):
    """Request body for /stt/."""
    audio_base64: str  # Base64 encoded audio data (any format soundfile can read).
@app.post("/stt/")
async def speech_to_text(request: STTRequest):
    """Convert base64-encoded audio to Chabacano text.

    Decodes request.audio_base64, normalizes the waveform to 16 kHz mono,
    and runs greedy CTC decoding with the Wav2Vec2 model.

    Raises:
        HTTPException 503: STT model was not loaded at startup.
        HTTPException 500: audio decoding or inference failed.
    """
    if stt_model is None or stt_processor is None:
        raise HTTPException(status_code=503, detail="STT model not loaded")
    try:
        # Decode base64 audio
        audio_bytes = base64.b64decode(request.audio_base64)
        # Read audio file
        audio_buffer = io.BytesIO(audio_bytes)
        speech_array, sampling_rate = sf.read(audio_buffer)
        # FIX: downmix multi-channel audio to mono. sf.read returns a
        # (frames, channels) array for stereo input, which previously was
        # resampled along the wrong axis and fed as a 2-D array to a model
        # that expects a 1-D waveform.
        if speech_array.ndim > 1:
            speech_array = speech_array.mean(axis=1)
        # Resample to 16kHz if needed (Wav2Vec2 expects 16 kHz input)
        if sampling_rate != 16000:
            speech_array = librosa.resample(speech_array, orig_sr=sampling_rate, target_sr=16000)
            sampling_rate = 16000
        # Process audio
        input_values = stt_processor(
            speech_array,
            sampling_rate=sampling_rate,
            return_tensors="pt"
        ).input_values
        # Perform inference
        with torch.no_grad():
            logits = stt_model(input_values).logits
        # Greedy decode: argmax over the vocabulary, then CTC-collapse.
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = stt_processor.batch_decode(predicted_ids)[0]
        return {"err": None, "transcription": transcription}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"STT error: {str(e)}")
@app.post("/stt/upload/")
async def speech_to_text_upload(file: UploadFile = File(...)):
    """Convert an uploaded audio file to Chabacano text.

    Same pipeline as /stt/ but reads the waveform from a multipart upload:
    normalize to 16 kHz mono, then greedy CTC decode with Wav2Vec2.

    Raises:
        HTTPException 503: STT model was not loaded at startup.
        HTTPException 500: audio decoding or inference failed.
    """
    if stt_model is None or stt_processor is None:
        raise HTTPException(status_code=503, detail="STT model not loaded")
    try:
        # Read uploaded file
        audio_bytes = await file.read()
        audio_buffer = io.BytesIO(audio_bytes)
        speech_array, sampling_rate = sf.read(audio_buffer)
        # FIX: downmix multi-channel audio to mono. sf.read returns a
        # (frames, channels) array for stereo input, which previously was
        # resampled along the wrong axis and fed as a 2-D array to a model
        # that expects a 1-D waveform.
        if speech_array.ndim > 1:
            speech_array = speech_array.mean(axis=1)
        # Resample to 16kHz if needed (Wav2Vec2 expects 16 kHz input)
        if sampling_rate != 16000:
            speech_array = librosa.resample(speech_array, orig_sr=sampling_rate, target_sr=16000)
            sampling_rate = 16000
        # Process audio
        input_values = stt_processor(
            speech_array,
            sampling_rate=sampling_rate,
            return_tensors="pt"
        ).input_values
        # Perform inference
        with torch.no_grad():
            logits = stt_model(input_values).logits
        # Greedy decode: argmax over the vocabulary, then CTC-collapse.
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = stt_processor.batch_decode(predicted_ids)[0]
        return {"err": None, "transcription": transcription}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"STT error: {str(e)}")
# --- TTS ENDPOINT ---
class TTSRequest(BaseModel):
    """Request body for /tts/ and /tts/base64/."""
    text: str  # Chabacano text to synthesize.
@app.post("/tts/")
async def text_to_speech(request: TTSRequest):
    """Synthesize Chabacano speech for request.text and stream it as a WAV file."""
    if tts_model is None or tts_processor is None or tts_vocoder is None:
        raise HTTPException(status_code=503, detail="TTS model not loaded")
    try:
        # Tokenize the input text for SpeechT5.
        tokenized = tts_processor(text=request.text, return_tensors="pt")
        # Synthesize the waveform with the fixed speaker embedding + HiFi-GAN vocoder.
        with torch.no_grad():
            waveform = tts_model.generate_speech(
                tokenized["input_ids"],
                speaker_embeddings,
                vocoder=tts_vocoder
            )
        # Serialize the waveform into an in-memory 16 kHz WAV file.
        wav_buffer = io.BytesIO()
        sf.write(wav_buffer, waveform.numpy(), samplerate=16000, format='WAV')
        wav_buffer.seek(0)
        # Stream the buffer back as a downloadable attachment.
        return StreamingResponse(
            wav_buffer,
            media_type="audio/wav",
            headers={"Content-Disposition": "attachment; filename=speech.wav"}
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"TTS error: {str(exc)}")
@app.post("/tts/base64/")
async def text_to_speech_base64(request: TTSRequest):
    """Synthesize Chabacano speech for request.text and return base64-encoded WAV."""
    if tts_model is None or tts_processor is None or tts_vocoder is None:
        raise HTTPException(status_code=503, detail="TTS model not loaded")
    try:
        # Tokenize the input text for SpeechT5.
        tokenized = tts_processor(text=request.text, return_tensors="pt")
        # Synthesize the waveform with the fixed speaker embedding + HiFi-GAN vocoder.
        with torch.no_grad():
            waveform = tts_model.generate_speech(
                tokenized["input_ids"],
                speaker_embeddings,
                vocoder=tts_vocoder
            )
        # Serialize to an in-memory 16 kHz WAV, then base64-encode the bytes.
        wav_buffer = io.BytesIO()
        sf.write(wav_buffer, waveform.numpy(), samplerate=16000, format='WAV')
        encoded = base64.b64encode(wav_buffer.getvalue()).decode('utf-8')
        return {"err": None, "audio_base64": encoded}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"TTS error: {str(exc)}")