from fastapi import FastAPI, File, Form, HTTPException
import datetime
import time
import os
from typing import Optional

import torch
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM, AutoConfig
from huggingface_hub import hf_hub_download
from fuzzywuzzy import fuzz

from utils import ffmpeg_read, query_dummy, query_raw, find_different
from pitch_scoring import extract_pitch_pattern, compute_accent_score

## config
API_TOKEN = os.environ["API_TOKEN"]
MODEL_PATH = os.environ["MODEL_PATH"]
PITCH_PATH = os.environ["PITCH_PATH"]
QUANTIZED_MODEL_PATH = hf_hub_download(repo_id=MODEL_PATH, filename="quantized_model.pt", token=API_TOKEN)
QUANTIZED_PITCH_MODEL_PATH = hf_hub_download(repo_id=PITCH_PATH, filename="quantized_model.pt", token=API_TOKEN)

## word preprocessor
processor_with_lm = Wav2Vec2ProcessorWithLM.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)
processor = Wav2Vec2Processor.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)

### quantized model
config = AutoConfig.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)
dummy_model = Wav2Vec2ForCTC(config)
quantized_model = torch.quantization.quantize_dynamic(
    dummy_model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True
)
quantized_model.load_state_dict(torch.load(QUANTIZED_MODEL_PATH))

## pitch preprocessor
processor_pitch = Wav2Vec2Processor.from_pretrained(PITCH_PATH, use_auth_token=API_TOKEN)

### quantized pitch model
config = AutoConfig.from_pretrained(PITCH_PATH, use_auth_token=API_TOKEN)
dummy_pitch_model = Wav2Vec2ForCTC(config)
quantized_pitch_model = torch.quantization.quantize_dynamic(
    dummy_pitch_model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True
)
quantized_pitch_model.load_state_dict(torch.load(QUANTIZED_PITCH_MODEL_PATH))

app = FastAPI()


@app.get("/")
def read_root():
    return {"Message": "Application startup complete"}


@app.post("/naomi_api_score/")
async def predict(
    file: bytes = File(...),
    word: str = Form(...),
    pitch: Optional[str] = Form(None),
    temperature: int = Form(...),
):
    """Transform the input audio, run the quantized models to get text and
    pitch predictions, and calculate a Levenshtein-distance score.

    Parameters
    ----------
    file : bytes
        Input audio file.
    word : str
        Target hiragana word used to calculate the word score.
    pitch : str, optional
        Target pitch pattern used to calculate the pitch score.
    temperature : int
        The difficulty of the AI model.

    Returns
    -------
    timestamp : str
        Current time, Year-Month-Day-Hours:Minutes:Seconds.
    running_time : str
        Running time in seconds.
    error_message : str
        Error message, if any.
    audio_duration : float
        Duration of the source audio in seconds.
    word_predict : str
        Predicted text.
    pitch_predict : str
        Predicted pitch pattern.
    wrong_word_index : str (e.g. "100")
        Per-character mask of mismatches against the target word.
    wrong_pitch_index : str (e.g. "100")
        Per-character mask of mismatches against the target pitch.
    score : int
        Combined Levenshtein-distance score over word and pitch.
    """
    upload_audio = ffmpeg_read(file, sampling_rate=16000)
    audio_duration = len(upload_audio) / 16000
    current_time = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
    start_time = time.time()

    error_message, score, word_preds, pitch_preds = None, None, None, None
    word_preds = query_raw(
        upload_audio, word, processor, processor_with_lm, quantized_model, temperature=temperature
    )
    if pitch is not None:
        if len(word) != len(pitch):
            error_message = "Length of word and pitch input is not equal"
        pitch_preds = query_dummy(upload_audio, processor_pitch, quantized_pitch_model)
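    # Scoring note: fuzz.ratio computes a Levenshtein-based similarity scaled
    # to 0-100 (e.g. fuzz.ratio("abc", "abd") == 67). Below, the hypothesis in
    # word_preds with the highest ratio against the target word is kept, and
    # when a pitch target is supplied the final score weights the word score
    # 2:1 over the pitch score: score = (2 * word_score + pitch_score) / 3.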
    # find best word
    word_score_list = []
    for word_predict in word_preds:
        word_score_list.append(fuzz.ratio(word, word_predict[0]))
    word_score = max(word_score_list)
    best_word_predict = word_preds[word_score_list.index(word_score)][0]
    wrong_word = find_different(word, best_word_predict)  # get wrong word

    # find best pitch
    if pitch_preds is not None:
        best_pitch_predict = pitch_preds.replace(" ", "")
        if len(best_pitch_predict) < len(best_word_predict):
            # pad with "1" up to the word length
            best_pitch_predict = best_pitch_predict + "1" * (len(best_word_predict) - len(best_pitch_predict))
        else:
            # truncate to the word length
            best_pitch_predict = best_pitch_predict[:len(best_word_predict)]
        pitch_score = fuzz.ratio(pitch, best_pitch_predict)
        score = int((word_score * 2 + pitch_score) / 3)
        wrong_pitch = find_different(pitch, best_pitch_predict)  # get wrong pitch
    else:
        score = int(word_score)
        best_pitch_predict = None
        wrong_pitch = None

    return {
        "timestamp": current_time,
        "running_time": f"{round(time.time() - start_time, 4)} s",
        "error_message": error_message,
        "audio_duration": audio_duration,
        "word_predict": best_word_predict,
        "pitch_predict": best_pitch_predict,
        "wrong_word_index": wrong_word,
        "wrong_pitch_index": wrong_pitch,
        "score": score,
    }


@app.post("/pitch_score/")
async def pitch_score(
    file: bytes = File(...),
    expected_pattern: str = Form(...),
    mora_count: int = Form(...),
):
    """Score pitch accent by comparing the detected HL pattern to the expected pattern.

    Parameters
    ----------
    file : bytes
        WAV audio file (16 kHz mono recommended).
    expected_pattern : str
        Expected HL pattern string, e.g. "LHH".
    mora_count : int
        Number of moras in the word.

    Returns
    -------
    detected_pattern : str
        HL pattern detected from the audio, e.g. "LHH".
    accent_score : int
        Similarity score 0-100 against the expected pattern.
    """
    if mora_count <= 0:
        raise HTTPException(status_code=422, detail="mora_count must be > 0")

    # Decode audio bytes via ffmpeg (handles WAV, MP3, etc.)
    try:
        audio = ffmpeg_read(file, sampling_rate=16000)
    except Exception as e:
        raise HTTPException(status_code=422, detail=f"Audio decode failed: {e}")

    detected_pattern, _ = extract_pitch_pattern(audio, mora_count=mora_count, sr=16000)
    if not detected_pattern:
        # Silent or unvoiced audio: return a zero score gracefully
        return {"detected_pattern": "", "accent_score": 0}

    score = compute_accent_score(detected_pattern, expected_pattern)
    return {"detected_pattern": detected_pattern, "accent_score": score}
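
# Example client call (a sketch; the host, port, file name, and the form of
# the "pitch" value are assumptions, not defined by this module):
#
#   import requests
#
#   with open("sample.wav", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/naomi_api_score/",
#           files={"file": ("sample.wav", f, "audio/wav")},
#           data={"word": "かがみ", "pitch": "011", "temperature": 1},
#       )
#   print(resp.json())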