# naomi-app-api / main.py
# Author: Vu Minh Chien
# feat: add /pitch_score/ endpoint with librosa F0 HL pattern analysis (commit e931e45)
from fastapi import FastAPI, File, Form, HTTPException
import datetime
import time
import torch
from typing import Optional
import os
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM, AutoConfig
from huggingface_hub import hf_hub_download
from fuzzywuzzy import fuzz
from utils import ffmpeg_read, query_dummy, query_raw, find_different
from pitch_scoring import extract_pitch_pattern, compute_accent_score
## config
# Credentials and Hub repo ids come from the environment; a missing variable
# raises KeyError at import time, so misconfiguration fails fast on startup.
API_TOKEN = os.environ["API_TOKEN"]
MODEL_PATH = os.environ["MODEL_PATH"]
PITCH_PATH = os.environ["PITCH_PATH"]
# Download the pre-quantized state dicts from the Hugging Face Hub
# (hf_hub_download caches locally, so repeated startups reuse the file).
QUANTIZED_MODEL_PATH = hf_hub_download(repo_id=MODEL_PATH, filename='quantized_model.pt', token=API_TOKEN)
QUANTIZED_PITCH_MODEL_PATH = hf_hub_download(repo_id=PITCH_PATH, filename='quantized_model.pt', token=API_TOKEN)
## word preprocessor
# Two processors for the word model: one with a language-model-backed beam
# decoder (used by query_raw) and one plain CTC processor.
processor_with_lm = Wav2Vec2ProcessorWithLM.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)
processor = Wav2Vec2Processor.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)
### quantized model
# Build an uninitialized Wav2Vec2ForCTC from config, dynamically quantize its
# Linear layers to int8, then load the matching quantized weights on top.
config = AutoConfig.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)
dummy_model = Wav2Vec2ForCTC(config)
quantized_model = torch.quantization.quantize_dynamic(dummy_model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True)
quantized_model.load_state_dict(torch.load(QUANTIZED_MODEL_PATH))
## pitch preprocessor
processor_pitch = Wav2Vec2Processor.from_pretrained(PITCH_PATH, use_auth_token=API_TOKEN)
### quantized pitch model
# Same quantize-then-load pattern for the pitch model. `config` is reused on
# purpose; the word model above has already been constructed from the old value.
# NOTE(review): torch.load without map_location assumes the checkpoint was
# saved for the current device (CPU) — confirm.
config = AutoConfig.from_pretrained(PITCH_PATH, use_auth_token=API_TOKEN)
dummy_pitch_model = Wav2Vec2ForCTC(config)
quantized_pitch_model = torch.quantization.quantize_dynamic(dummy_pitch_model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True)
quantized_pitch_model.load_state_dict(torch.load(QUANTIZED_PITCH_MODEL_PATH))
app = FastAPI()
@app.get("/")
def read_root():
    """Health check: confirm the service finished startup and is reachable."""
    status = {"Message": "Application startup complete"}
    return status
@app.post("/naomi_api_score/")
async def predict(
    file: bytes = File(...),
    word: str = Form(...),
    pitch: Optional[str] = Form(None),
    temperature: int = Form(...),
):
    """Transcribe the uploaded audio and score it against a target word/pitch.

    The audio is decoded to 16 kHz PCM, run through the quantized word model
    (and, when ``pitch`` is supplied, the quantized pitch model), and the best
    transcription candidate is scored with Levenshtein ratio (fuzz.ratio).

    Parameters
    ----------
    file : bytes
        Input audio file (any format ffmpeg can decode).
    word : str
        True hiragana word used to compute the word score.
    pitch : str, optional
        True pitch string used to compute the pitch score; if omitted, only
        the word score is returned.
    temperature : int
        Decoding difficulty knob forwarded to ``query_raw``.

    Returns
    -------
    dict
        timestamp : str            -- request time, "%Y-%h-%d-%H:%M:%S"
        running_time : str         -- wall-clock processing time, e.g. "0.42 s"
        error_message : str | None -- set when len(word) != len(pitch)
        audio_duration : float     -- duration of the uploaded audio in seconds
        word_predict : str         -- best transcription candidate
        pitch_predict : str | None -- predicted pitch string (None when no pitch given)
        wrong_word_index : output of find_different(word, best candidate)
        wrong_pitch_index : output of find_different(pitch, predicted pitch), or None
        score : int                -- weighted Levenshtein score; (2*word + pitch)/3
                                      when pitch is scored, else the word score alone
    """
    # Decode the raw upload to a 16 kHz waveform.
    upload_audio = ffmpeg_read(file, sampling_rate=16000)
    audio_duration = len(upload_audio) / 16000  # seconds
    current_time = datetime.datetime.now().strftime("%Y-%h-%d-%H:%M:%S")
    start_time = time.time()
    error_message, score, word_preds, pitch_preds = None, None, None, None
    # query_raw returns transcription candidates; each candidate is indexed
    # with [0] below, so candidate[0] is the transcription text.
    word_preds = query_raw(upload_audio, word, processor, processor_with_lm, quantized_model, temperature=temperature)
    if pitch is not None:
        if len(word) != len(pitch):
            # Reported to the caller, but processing still continues best-effort.
            error_message = "Length of word and pitch input is not equal"
        pitch_preds = query_dummy(upload_audio, processor_pitch, quantized_pitch_model)
    # find best word: keep the candidate with the highest fuzz.ratio to the target.
    word_score_list = []
    for word_predict in word_preds:
        word_score_list.append(fuzz.ratio(word, word_predict[0]))
    word_score = max(word_score_list)
    best_word_predict = word_preds[word_score_list.index(word_score)][0]
    wrong_word = find_different(word, best_word_predict)  # get wrong word positions
    # find best pitch
    if pitch_preds is not None:
        # Pitch prediction is a space-separated string; strip separators first.
        best_pitch_predict = pitch_preds.replace(" ", "")
        if len(best_pitch_predict) < len(best_word_predict):
            # Pad to the word length. NOTE(review): "1" presumably encodes the
            # default/low pitch level — confirm against the pitch model labels.
            best_pitch_predict = best_pitch_predict + "1" * (len(best_word_predict) - len(best_pitch_predict))
        else:
            best_pitch_predict = best_pitch_predict[:len(best_word_predict)]  # truncate to max len
        pitch_score = fuzz.ratio(pitch, best_pitch_predict)
        # Word accuracy is weighted twice as heavily as pitch accuracy.
        score = int((word_score * 2 + pitch_score) / 3)
        wrong_pitch = find_different(pitch, best_pitch_predict)  # get wrong pitch positions
    else:
        score = int(word_score)
        best_pitch_predict = None
        wrong_pitch = None
    return {"timestamp": current_time,
            "running_time": f"{round(time.time() - start_time, 4)} s",
            "error_message": error_message,
            "audio_duration": audio_duration,
            "word_predict": best_word_predict,
            "pitch_predict": best_pitch_predict,
            "wrong_word_index": wrong_word,
            "wrong_pitch_index": wrong_pitch,
            "score": score,
            }
@app.post("/pitch_score/")
async def pitch_score(
    file: bytes = File(...),
    expected_pattern: str = Form(...),
    mora_count: int = Form(...),
):
    """Score a pitch-accent recording against an expected high/low (HL) pattern.

    Parameters
    ----------
    file : bytes
        Audio payload; any format ffmpeg can decode (16 kHz mono WAV recommended).
    expected_pattern : str
        Target HL pattern string, e.g. "LHH".
    mora_count : int
        Number of moras in the target word; must be positive.

    Returns
    -------
    dict
        detected_pattern : str -- HL pattern extracted from the audio
                                  ("" when no voiced frames were found)
        accent_score : int     -- similarity to ``expected_pattern``, 0-100

    Raises
    ------
    HTTPException
        422 when ``mora_count`` is non-positive or the audio cannot be decoded.
    """
    if mora_count <= 0:
        raise HTTPException(status_code=422, detail="mora_count must be > 0")
    # ffmpeg takes care of container/codec detection, so MP3 etc. also work.
    try:
        waveform = ffmpeg_read(file, sampling_rate=16000)
    except Exception as e:
        raise HTTPException(status_code=422, detail=f"Audio decode failed: {e}")
    pattern, _ = extract_pitch_pattern(waveform, mora_count=mora_count, sr=16000)
    if not pattern:
        # Silent or fully unvoiced audio: degrade gracefully to a zero score
        # rather than raising.
        return {"detected_pattern": "", "accent_score": 0}
    return {
        "detected_pattern": pattern,
        "accent_score": compute_accent_score(pattern, expected_pattern),
    }