# llama-server / app.py
# Last commit 0eba9b9 ("Update app.py") by bigeco (verified)
import html

import yaml
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse
from pydantic import BaseModel

from model.llama3 import LLaMA3
# ---------------- Config loading ----------------
# The path is relative, so it is resolved against the process's current
# working directory. The file is decoded as UTF-8 explicitly: without
# `encoding`, open() falls back to the platform locale encoding (e.g. cp949 or
# cp1252 on Windows), which breaks or mis-decodes non-ASCII YAML content.
with open("config/llama3.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# ---------------- Model initialisation ----------------
# A single model instance is built at import time and reused by the request
# handlers below.
llama3 = LLaMA3(config)

# ---------------- FastAPI app ----------------
app = FastAPI(
    title="Korean Pronunciation Correction API",
    description="FastAPI + LLaMA3 기반 발음 교정 서버",
    version="1.0.0",
)
# ---------------- Input model ----------------
class InputData(BaseModel):
    """JSON request body for ``POST /generate``.

    The endpoint passes both fields, in this order, to ``llama3.generate``.
    """

    # The user's pronunciation (the HTML form labels its same-named field as
    # the "wrong pronunciation").
    user_input: str
    # The target pronunciation (labelled "correct pronunciation" in the form).
    correct_input: str
# ---------------- API: JSON POST ----------------
@app.post("/generate")
async def generate_correction(data: InputData):
    """Run the model on a JSON request and wrap its output.

    Args:
        data: Parsed request body carrying ``user_input`` and
            ``correct_input``.

    Returns:
        A dict of the form ``{"result": <value returned by llama3.generate>}``.
    """
    # NOTE(review): llama3.generate is a plain synchronous call inside an
    # async handler, so the event loop is blocked for as long as generation
    # takes. If LLaMA3 is thread-safe, consider running it in a threadpool.
    spoken, target = data.user_input, data.correct_input
    return {"result": llama3.generate(spoken, target)}
# ---------------- HTML UI ----------------
@app.get("/", response_class=HTMLResponse)
async def form_ui():
    """Serve the static pronunciation-correction test form.

    The form POSTs two required text fields to ``/submit``: ``user_input``
    (the wrong pronunciation) and ``correct_input`` (the correct
    pronunciation), pre-filled with an example pair.

    Returns:
        The HTML page as a string.
    """
    # The labels/values were previously stored mis-encoded (UTF-8 bytes read
    # as cp874), which rendered as garbage; they are restored to the intended
    # Korean text and emoji here.
    return """
<html>
  <head>
    <meta charset="utf-8">
    <title>Korean Pronunciation Correction</title>
  </head>
  <body style="font-family: sans-serif; max-width: 600px; margin: auto; padding: 2rem;">
    <h1>🗣️ 발음 교정 테스트</h1>
    <form action="/submit" method="post">
      <label for="user_input">🧐 잘못된 발음:</label><br>
      <input type="text" id="user_input" name="user_input" value="박라" required><br><br>
      <label for="correct_input">🎯 올바른 발음:</label><br>
      <input type="text" id="correct_input" name="correct_input" value="발라" required><br><br>
      <input type="submit" value="교정 실행">
    </form>
  </body>
</html>
"""
# ---------------- Result rendering ----------------
from fastapi.responses import HTMLResponse
from fastapi import Form, Request
import traceback
@app.post("/submit", response_class=HTMLResponse)
async def handle_form(
    request: Request,
    user_input: str = Form(...),
    correct_input: str = Form(...),
):
    """Handle the form posted from ``/`` and render the model output as HTML.

    Args:
        request: The incoming request (not used by the body).
        user_input: Form field ``user_input`` (the wrong pronunciation).
        correct_input: Form field ``correct_input`` (the correct
            pronunciation).

    Returns:
        An HTML page echoing both inputs and the value returned by
        ``llama3.generate``. If generation raises any ``Exception``, an HTML
        error page showing the exception message and traceback is returned
        instead.
    """
    # NOTE(review): llama3.generate is a synchronous call inside an async
    # handler, so the event loop is blocked for the duration of generation.
    #
    # Every interpolated value below goes through html.escape: the form
    # fields are attacker-controlled, and model output / exception text may
    # echo them back, so unescaped interpolation is a reflected-XSS hole.
    try:
        result = llama3.generate(user_input, correct_input)
    except Exception as e:
        # Top-level request boundary: render the failure as a page.
        # NOTE(review): sending the full traceback to clients discloses
        # server internals; consider gating it behind a debug setting.
        error_details = traceback.format_exc()
        return f"""
<html>
  <head>
    <meta charset="utf-8">
    <title>에러</title>
  </head>
  <body style="font-family: sans-serif; max-width: 600px; margin: auto; padding: 2rem;">
    <h1>❌ 서버 오류 발생</h1>
    <p><strong>오류 메시지:</strong></p>
    <pre>{html.escape(str(e))}</pre>
    <hr>
    <p><strong>에러 상세:</strong></p>
    <pre>{html.escape(error_details)}</pre>
    <br>
    <a href="/">돌아가기</a>
  </body>
</html>
"""

    return f"""
<html>
  <head>
    <meta charset="utf-8">
    <title>결과</title>
  </head>
  <body style="font-family: sans-serif; max-width: 600px; margin: auto; padding: 2rem;">
    <h1>✅ 결과</h1>
    <p><strong>입력된 발음:</strong> {html.escape(user_input)}</p>
    <p><strong>정답 발음:</strong> {html.escape(correct_input)}</p>
    <hr>
    <h2>🧠 모델 응답:</h2>
    <pre>{html.escape(str(result))}</pre>
    <br>
    <a href="/">다시 시도하기</a>
  </body>
</html>
"""
# ---------------- Health check ----------------
@app.get("/health")
async def health_check():
    """Report liveness together with the configured model id and device.

    Returns:
        ``{"status": "ok", "model": ..., "device": ...}`` where the last two
        values come from the ``model`` section of the YAML config loaded at
        import time. A missing key raises ``KeyError``, which propagates out
        of the handler.
    """
    model_section = config["model"]
    payload = {"status": "ok"}
    payload["model"] = model_section["id"]
    payload["device"] = model_section["device"]
    return payload