File size: 6,841 Bytes
e931e45
f787cd1
4c7e2b8
f787cd1
61fe542
241ce22
f787cd1
8732281
f787cd1
 
 
 
e931e45
f787cd1
 
 
 
 
07d2b90
447ae70
 
f787cd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aba3792
 
 
61fe542
aba3792
 
4c7e2b8
 
61fe542
4c7e2b8
 
61fe542
 
 
 
 
4c7e2b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24c223a
4c7e2b8
 
 
61fe542
24c223a
4c7e2b8
24c223a
61fe542
 
4c7e2b8
24c223a
4c7e2b8
 
 
 
 
 
 
 
 
24c223a
 
 
 
 
 
 
 
 
4c7e2b8
 
 
 
29b492f
4c7e2b8
 
702e8c4
 
 
 
 
 
8732281
e931e45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
from fastapi import FastAPI, File, Form, HTTPException
import datetime
import time
import torch
from typing import Optional

import os
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM, AutoConfig
from huggingface_hub import hf_hub_download
from fuzzywuzzy import fuzz
from utils import ffmpeg_read, query_dummy, query_raw, find_different
from pitch_scoring import extract_pitch_pattern, compute_accent_score

## config — all secrets/paths come from the environment; missing vars fail fast here.
API_TOKEN = os.environ["API_TOKEN"]
MODEL_PATH = os.environ["MODEL_PATH"]  # HF repo id of the word (ASR) model
PITCH_PATH = os.environ["PITCH_PATH"]  # HF repo id of the pitch model

# Download the pre-quantized state dicts from the Hugging Face Hub (cached locally).
QUANTIZED_MODEL_PATH = hf_hub_download(repo_id=MODEL_PATH, filename='quantized_model.pt', token=API_TOKEN)
QUANTIZED_PITCH_MODEL_PATH = hf_hub_download(repo_id=PITCH_PATH, filename='quantized_model.pt', token=API_TOKEN)


## word preprocessor — LM-boosted and plain CTC processors for the word model
processor_with_lm = Wav2Vec2ProcessorWithLM.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)
processor = Wav2Vec2Processor.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)

### quantized model — build an uninitialized model from config, apply dynamic int8
### quantization to the Linear layers, then load the matching quantized weights.
config = AutoConfig.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)
dummy_model = Wav2Vec2ForCTC(config)
quantized_model = torch.quantization.quantize_dynamic(dummy_model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True)
quantized_model.load_state_dict(torch.load(QUANTIZED_MODEL_PATH))

## pitch preprocessor 
processor_pitch = Wav2Vec2Processor.from_pretrained(PITCH_PATH, use_auth_token=API_TOKEN)

### quantized pitch model — same recipe as above; note `config` is deliberately
### reused (overwritten) here, so ordering of these statements matters.
config = AutoConfig.from_pretrained(PITCH_PATH, use_auth_token=API_TOKEN)
dummy_pitch_model = Wav2Vec2ForCTC(config)
quantized_pitch_model = torch.quantization.quantize_dynamic(dummy_pitch_model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True)
quantized_pitch_model.load_state_dict(torch.load(QUANTIZED_PITCH_MODEL_PATH))

app = FastAPI()


@app.get("/")
def read_root():
    """Health-check root endpoint; confirms the service finished starting."""
    status = {"Message": "Application startup complete"}
    return status


@app.post("/naomi_api_score/")
async def predict(
        file: bytes = File(...),
        word: str = Form(...),
        pitch: Optional[str] = Form(None),
        temperature: int = Form(...),
):
    """Transcribe the uploaded audio and score it against a target word
    (and, optionally, a target pitch string) using Levenshtein ratio.

    Parameters
    ----------
    file : bytes
        Input audio file (any container ffmpeg can decode).
    word : str
        True hiragana word to calculate the word score against.
    pitch : Optional[str]
        True pitch string; pitch prediction/scoring is skipped when omitted.
    temperature : int
        Difficulty of the AI model, forwarded to the decoding step.

    Returns
    -------
    dict with keys:
        timestamp : str
            Current time "Year-Mon-Day-Hours:Minutes:Seconds".
        running_time : str
            Processing time, e.g. "0.1234 s".
        error_message : Optional[str]
            Set when word and pitch lengths differ; otherwise None.
        audio_duration : float
            Duration of the uploaded audio in seconds.
        word_predict : str
            Best transcription candidate.
        pitch_predict : Optional[str]
            Predicted pitch string aligned to the best word prediction.
        wrong_word_index : str
            Positions that differ from the target word (e.g. "100").
        wrong_pitch_index : Optional[str]
            Positions that differ from the target pitch (e.g. "100").
        score : int
            Combined Levenshtein score, 0-100.
    """
    upload_audio = ffmpeg_read(file, sampling_rate=16000)
    audio_duration = len(upload_audio) / 16000
    # NOTE(review): %h (abbreviated month) is a platform-dependent strftime code.
    current_time = datetime.datetime.now().strftime("%Y-%h-%d-%H:%M:%S")
    start_time = time.time()
    error_message, score, word_preds, pitch_preds = None, None, None, None

    word_preds = query_raw(upload_audio, word, processor, processor_with_lm, quantized_model, temperature=temperature)
    if pitch is not None:
        if len(word) != len(pitch):
            # Report the mismatch but still run pitch prediction (best effort).
            error_message = "Length of word and pitch input is not equal"
        pitch_preds = query_dummy(upload_audio, processor_pitch, quantized_pitch_model)

    # Pick the transcription candidate closest to the target word.
    word_score_list = [fuzz.ratio(word, candidate[0]) for candidate in word_preds]
    word_score = max(word_score_list)
    best_word_predict = word_preds[word_score_list.index(word_score)][0]
    wrong_word = find_different(word, best_word_predict)  # get wrong word

    # Score the pitch pattern, aligned to the length of the best word prediction.
    if pitch_preds is not None:
        best_pitch_predict = pitch_preds.replace(" ", "")
        if len(best_pitch_predict) < len(best_word_predict):
            # Pad missing pitch marks with "1" so both strings are comparable.
            best_pitch_predict = best_pitch_predict + "1" * (len(best_word_predict) - len(best_pitch_predict))
        else:
            best_pitch_predict = best_pitch_predict[:len(best_word_predict)]  # truncate to max len
        pitch_score = fuzz.ratio(pitch, best_pitch_predict)
        # Word accuracy is weighted twice as heavily as pitch accuracy.
        score = int((word_score * 2 + pitch_score) / 3)
        wrong_pitch = find_different(pitch, best_pitch_predict)  # get wrong pitch
    else:
        score = int(word_score)
        best_pitch_predict = None
        wrong_pitch = None

    return {"timestamp": current_time,
            "running_time": f"{round(time.time() - start_time, 4)} s",
            "error_message": error_message,
            "audio_duration": audio_duration,
            "word_predict": best_word_predict,
            "pitch_predict": best_pitch_predict,
            "wrong_word_index": wrong_word,
            "wrong_pitch_index": wrong_pitch,
            "score": score,
            }


@app.post("/pitch_score/")
async def pitch_score(
    file: bytes = File(...),
    expected_pattern: str = Form(...),
    mora_count: int = Form(...),
):
    """Score pitch accent by comparing detected HL pattern to expected pattern.

    Parameters
    ----------
    file : bytes
        WAV audio file (16kHz mono recommended).
    expected_pattern : str
        Expected HL pattern string, e.g. "LHH".
    mora_count : int
        Number of moras in the word; must be positive.

    Returns
    -------
    detected_pattern : str
        HL pattern detected from audio, e.g. "LHH".
    accent_score : int
        Similarity score 0-100 vs expected pattern.

    Raises
    ------
    HTTPException
        422 when mora_count is non-positive or the audio cannot be decoded.
    """
    if mora_count <= 0:
        raise HTTPException(status_code=422, detail="mora_count must be > 0")

    # Decode audio bytes via ffmpeg (handles WAV, MP3, etc.)
    try:
        audio = ffmpeg_read(file, sampling_rate=16000)
    except Exception as e:
        # Chain the cause so the original decode error survives in tracebacks.
        raise HTTPException(status_code=422, detail=f"Audio decode failed: {e}") from e

    detected_pattern, _ = extract_pitch_pattern(audio, mora_count=mora_count, sr=16000)

    if not detected_pattern:
        # Silent or unvoiced audio — return 0 score gracefully
        return {"detected_pattern": "", "accent_score": 0}

    score = compute_accent_score(detected_pattern, expected_pattern)
    return {"detected_pattern": detected_pattern, "accent_score": score}