File size: 3,598 Bytes
80de494
7cf0dc9
80de494
 
 
 
 
cbbf00c
80de494
7cf0dc9
80de494
 
7cf0dc9
80de494
7cf0dc9
80de494
 
7cf0dc9
 
 
80de494
 
 
 
7cf0dc9
80de494
 
 
 
 
 
 
7cf0dc9
80de494
7cf0dc9
 
80de494
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7cf0dc9
 
 
80de494
0eba9b9
 
 
 
80de494
0eba9b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80de494
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7cf0dc9
 
 
80de494
 
 
7cf0dc9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import html
import logging

import yaml
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse
from pydantic import BaseModel

from model.llama3 import LLaMA3

# ---------------- Load configuration ----------------
# Read once at import time. The path is relative to the current working
# directory (plain open() semantics), so the server is presumably started from
# the project root — confirm against the deployment setup.
# Decode explicitly as UTF-8: YAML is UTF-8 by convention, and without an
# explicit encoding open() would use the platform locale (e.g. cp949/cp1252 on
# Windows), which can fail on or garble non-ASCII config values.
with open("config/llama3.yaml", "r", encoding="utf-8") as f:
    # safe_load only builds plain Python objects (no arbitrary constructors).
    config = yaml.safe_load(f)

# ---------------- Model initialisation ----------------
# One model instance, created at import time and shared by the request handlers.
llama3 = LLaMA3(config)

# ---------------- FastAPI app ----------------
app = FastAPI(
    title="Korean Pronunciation Correction API",
    description="FastAPI + LLaMA3 ๊ธฐ๋ฐ˜ ๋ฐœ์Œ ๊ต์ • ์„œ๋ฒ„",
    version="1.0.0"
)

# ---------------- Input model ----------------
class InputData(BaseModel):
    """JSON request body accepted by ``POST /generate``.

    Both fields are passed positionally to ``llama3.generate`` in this order.
    """

    # The pronunciation as the user produced it (possibly incorrect).
    user_input: str
    # The correct target pronunciation to compare against.
    correct_input: str

# ---------------- API: JSON POST ----------------
@app.post("/generate")
async def generate_correction(data: InputData):
    """Run the model on a JSON ``InputData`` payload.

    Returns:
        ``{"result": <value returned by llama3.generate>}``.
    """
    # NOTE(review): llama3.generate is called without await inside an async
    # handler; if it performs blocking inference (presumably), it stalls the
    # event loop for other requests (including /health) meanwhile. Consider
    # offloading it once the model's thread-safety is confirmed.
    spoken, target = data.user_input, data.correct_input
    correction = llama3.generate(spoken, target)
    return dict(result=correction)

# ---------------- HTML UI ----------------
@app.get("/", response_class=HTMLResponse)
async def form_ui():
    """Serve the static HTML test form.

    The form POSTs the fields ``user_input`` and ``correct_input`` to
    ``/submit``. Both inputs carry a sample value and the ``required``
    attribute, so the browser will not submit them empty.
    """
    return """
    <html>
        <head>
            <title>Korean Pronunciation Correction</title>
        </head>
        <body style="font-family: sans-serif; max-width: 600px; margin: auto; padding: 2rem;">
            <h1>๐Ÿ—ฃ๏ธ ๋ฐœ์Œ ๊ต์ • ํ…Œ์ŠคํŠธ</h1>
            <form action="/submit" method="post">
                <label for="user_input">๐Ÿง ์ž˜๋ชป๋œ ๋ฐœ์Œ:</label><br>
                <input type="text" id="user_input" name="user_input" value="๋ฐ•๋ผ" required><br><br>
                
                <label for="correct_input">๐ŸŽฏ ์˜ฌ๋ฐ”๋ฅธ ๋ฐœ์Œ:</label><br>
                <input type="text" id="correct_input" name="correct_input" value="๋ฐœ๋ผ" required><br><br>
                
                <input type="submit" value="๊ต์ • ์‹คํ–‰">
            </form>
        </body>
    </html>
    """

# ---------------- Result rendering ----------------
from fastapi.responses import HTMLResponse
from fastapi import Form, Request
import traceback

@app.post("/submit", response_class=HTMLResponse)
async def handle_form(request: Request, user_input: str = Form(...), correct_input: str = Form(...)):
    """Run a correction for the HTML form submission and render a result page.

    Args:
        request: Injected by FastAPI; not used by this handler.
        user_input: The pronunciation as entered by the user (form field).
        correct_input: The target pronunciation (form field).

    Returns:
        The result page as an HTML string. If ``llama3.generate`` raises, an
        ``HTMLResponse`` with status 500 containing an error page instead.

    Every value interpolated into the markup goes through ``html.escape``. The
    form fields are user-controlled, and the model output and exception text
    can echo them. Inserting them raw would allow reflected XSS, e.g. a
    ``<script>`` tag typed into a form field.
    """
    try:
        # NOTE(review): called without await inside an async handler; if
        # generate does blocking inference it stalls the event loop meanwhile.
        result = llama3.generate(user_input, correct_input)
    except Exception as e:
        # Request-level boundary: record the failure server-side, then show an
        # error page instead of FastAPI's generic 500 response.
        logging.getLogger(__name__).exception("llama3.generate failed in /submit")
        error_details = traceback.format_exc()
        safe_message = html.escape(str(e))
        safe_details = html.escape(error_details)
        # NOTE(review): the traceback exposes server internals (paths, config
        # values) to the client; consider hiding it outside of development.
        error_page = f"""
        <html>
            <head><title>์—๋Ÿฌ</title></head>
            <body style="font-family: sans-serif; max-width: 600px; margin: auto; padding: 2rem;">
                <h1>โŒ ์„œ๋ฒ„ ์˜ค๋ฅ˜ ๋ฐœ์ƒ</h1>
                <p><strong>์˜ค๋ฅ˜ ๋ฉ”์‹œ์ง€:</strong></p>
                <pre>{safe_message}</pre>
                <hr>
                <p><strong>์—๋Ÿฌ ์ƒ์„ธ:</strong></p>
                <pre>{safe_details}</pre>
                <br>
                <a href="/">๋Œ์•„๊ฐ€๊ธฐ</a>
            </body>
        </html>
        """
        # A real 500 status lets clients and monitoring detect the failure;
        # the browser still renders the same page.
        return HTMLResponse(content=error_page, status_code=500)

    # str() mirrors what f-string interpolation did before escaping was added,
    # in case generate returns a non-str value.
    safe_user_input = html.escape(user_input)
    safe_correct_input = html.escape(correct_input)
    safe_result = html.escape(str(result))
    return f"""
    <html>
        <head><title>๊ฒฐ๊ณผ</title></head>
        <body style="font-family: sans-serif; max-width: 600px; margin: auto; padding: 2rem;">
            <h1>โœ… ๊ฒฐ๊ณผ</h1>
            <p><strong>์ž…๋ ฅ๋œ ๋ฐœ์Œ:</strong> {safe_user_input}</p>
            <p><strong>์ •๋‹ต ๋ฐœ์Œ:</strong> {safe_correct_input}</p>
            <hr>
            <h2>๐Ÿง  ๋ชจ๋ธ ์‘๋‹ต:</h2>
            <pre>{safe_result}</pre>
            <br>
            <a href="/">๋‹ค์‹œ ์‹œ๋„ํ•˜๊ธฐ</a>
        </body>
    </html>
    """
    
# ---------------- Health check ----------------
@app.get("/health")
async def health_check():
    """Report liveness together with the configured model id and device.

    Both values are read from the ``model`` section of the loaded YAML config.
    """
    model_section = config["model"]
    return dict(status="ok", model=model_section["id"], device=model_section["device"])