suhyun12 commited on
Commit
669fc71
·
1 Parent(s): 4b450fe

추론+main 한꺼번에 작성, kobart inference 수정, whisper_v1 사용

Browse files
Files changed (1) hide show
  1. app_v2/app_v2.py +138 -0
app_v2/app_v2.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from transformers import (
    PreTrainedTokenizerFast,
    BartForConditionalGeneration,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    AutoConfig
)
import torch
import torchaudio
import os
import random

# FastAPI application serving the STT + two-stage KoBART pipeline.
app = FastAPI()
# Run on GPU when available; every model and input tensor is moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
17
+
18
# Load Whisper (fine-tuned checkpoint "whisper_v1_ep11" from the jeju-potato repo).
whisper_model = WhisperForConditionalGeneration.from_pretrained(
    "jeju-potato/jeju_potato_models", subfolder="whisper_v1_ep11"
).to(device)
# The processor (feature extractor + tokenizer) comes from the base medium model.
whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
# Clear forced decoder ids so generation is not pinned to a preset language/task.
whisper_model.config.forced_decoder_ids = None
whisper_model.generation_config.forced_decoder_ids = None
25
+
26
# Load first KoBART (stage-1 seq2seq model applied to the Whisper transcript).
first_tokenizer = PreTrainedTokenizerFast.from_pretrained("jeju-potato/jeju_potato_models", subfolder="first_kobart_finetuned_v3")
first_model = BartForConditionalGeneration.from_pretrained(
    "jeju-potato/jeju_potato_models", subfolder="first_kobart_finetuned_v3"
).to(device)
# Clear forced decoder ids to keep generation unconstrained (mirrors the Whisper setup).
first_model.config.forced_decoder_ids = None
first_model.generation_config.forced_decoder_ids = None
33
+
34
# Load second KoBART (explicit config supplied to silence loading warnings).
second_config = AutoConfig.from_pretrained(
    "jeju-potato/jeju_potato_models", subfolder="second_kobart_finetuned_v2"
)
# NOTE(review): these label mappings are not used by seq2seq generation; they
# appear to exist only to match the checkpoint's config and suppress warnings.
second_config.num_labels = 2
second_config.id2label = {0: "NEGATIVE", 1: "POSITIVE"}
second_config.label2id = {"NEGATIVE": 0, "POSITIVE": 1}

second_tokenizer = PreTrainedTokenizerFast.from_pretrained("jeju-potato/jeju_potato_models", subfolder="second_kobart_finetuned_v2")
second_model = BartForConditionalGeneration.from_pretrained(
    "jeju-potato/jeju_potato_models",
    subfolder="second_kobart_finetuned_v2",
    config=second_config
).to(device)
# Clear forced decoder ids to keep generation unconstrained.
second_model.config.forced_decoder_ids = None
second_model.generation_config.forced_decoder_ids = None
50
+
51
# Whisper inference: audio file path -> transcript string.
def inference_whisper(audio_path):
    """Transcribe the audio file at ``audio_path`` with the fine-tuned Whisper model.

    Loads the file with torchaudio, downmixes to mono, resamples to the
    feature extractor's expected rate, and returns the decoded transcription.
    """
    waveform, sr = torchaudio.load(audio_path)
    # Downmix multi-channel audio to mono; otherwise just drop the channel dim.
    waveform = waveform.mean(dim=0) if waveform.shape[0] > 1 else waveform.squeeze(0)

    # Bug fix: Whisper's feature extractor only accepts 16 kHz input and raises
    # a ValueError for any other sampling rate. Resample when the source file
    # uses a different rate instead of failing.
    target_sr = whisper_processor.feature_extractor.sampling_rate  # 16000
    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=target_sr)
        sr = target_sr

    inputs = whisper_processor(
        waveform.numpy(),
        sampling_rate=sr,
        return_tensors="pt",
        task="transcribe"
    ).to(device)
    # Ensure no stale forced decoder ids / suppress lists constrain generation.
    whisper_model.generation_config.forced_decoder_ids = None
    whisper_model.generation_config.suppress_tokens = []

    with torch.no_grad():
        pred_ids = whisper_model.generate(inputs["input_features"], max_length=128)
    decoded = whisper_processor.batch_decode(pred_ids, skip_special_tokens=True)[0]
    return decoded
68
+
69
# KoBART inference: input text -> generated text for the given model/tokenizer pair.
def inference_kobart(text, model, tokenizer):
    """Generate output text from ``text`` using the supplied KoBART model and tokenizer."""
    batch = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    )

    with torch.no_grad():
        generated = model.generate(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
            max_length=128,
        )

    return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
83
+
84
# Full pipeline: Whisper STT -> first KoBART -> second KoBART.
def run_pipeline(audio_path):
    """Run all three stages on the audio file and return every intermediate text."""
    transcript = inference_whisper(audio_path)
    stage1_text = inference_kobart(transcript, first_model, first_tokenizer)
    stage2_text = inference_kobart(stage1_text, second_model, second_tokenizer)
    return transcript, stage1_text, stage2_text
90
+
91
# FastAPI endpoint: smoke-test the pipeline on a random local .wav file.
@app.get("/test_random")
def test_random_file():
    """Pick a random .wav file from the audio directory and run the full pipeline.

    The directory may be overridden with the ``AUDIO_DIR`` environment variable;
    it defaults to the original hard-coded path.
    """
    audio_dir = os.environ.get("AUDIO_DIR", r"D:\whisper\audio1_all")

    candidates = [f for f in os.listdir(audio_dir) if f.endswith(".wav")]
    if not candidates:
        # Bug fix: the old message claimed 'data/source_data', which is not the
        # directory actually scanned; report the real path instead.
        return {"error": f"No .wav files found in {audio_dir}"}

    chosen_file = random.choice(candidates)
    audio_path = os.path.join(audio_dir, chosen_file)

    # Run STT -> first KoBART -> second KoBART; surface any failure as JSON.
    try:
        whisper_text, first_kobart_text, second_kobart_text = run_pipeline(audio_path)
    except Exception as e:
        return {"error": str(e)}

    # Return every stage's output alongside the chosen filename.
    return {
        "filename": chosen_file,
        "whisper_result": whisper_text,
        "first_kobart_result": first_kobart_text,
        "second_kobart_result": second_kobart_text
    }
117
+
118
# Client uploads an audio file as multipart/form-data.
@app.post("/inference")
async def inference(audio: UploadFile=File(...)):
    """Run the STT + two-stage KoBART pipeline on an uploaded audio file."""
    # Save the uploaded file under temp/.
    os.makedirs("temp", exist_ok=True)
    # Security fix: keep only the basename so a crafted filename such as
    # '../../x.wav' cannot escape the temp/ directory (path traversal).
    safe_name = os.path.basename(audio.filename or "upload.wav")
    audio_path = os.path.join("temp", safe_name)
    with open(audio_path, "wb") as f:
        f.write(await audio.read())

    # Model inference; report failures as a 500 JSON response.
    try:
        whisper_text, first_kobart_text, second_kobart_text = run_pipeline(audio_path)
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})

    # Return every stage's output.
    return {
        "whisper_result": whisper_text,
        "first_kobart_result": first_kobart_text,
        "second_kobart_result": second_kobart_text
    }