communiKate / backend /main.py
ashfortune
라이브러리 수정
97207e4
import asyncio
import io
import os
from typing import Dict, List, Optional

import PIL.Image
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware

from services.classifier import MBTIClassifier
from services.llm_service import LLMService
from schemas import (
    AnalyzeRequest, AnalyzeResponse,
    OCRResponse,
    ChatStartRequest, ChatStartResponse,
    ChatRequest, ChatResponse,
    SimulateRequest, SimulateResponse,
)
app = FastAPI(title="CommuniKate API")


def _cors_origins_from_env(default: str = "*") -> List[str]:
    """Return the allowed CORS origins, read from ``CORS_ALLOW_ORIGINS``.

    The variable is a comma-separated list of origins; surrounding whitespace
    and blank entries are dropped. When the variable is unset or yields no
    origins, ``[default]`` is returned, which keeps the allow-all behaviour.
    """
    raw = os.environ.get("CORS_ALLOW_ORIGINS", "")
    origins = [origin.strip() for origin in raw.split(",") if origin.strip()]
    return origins or [default]


# CORS configuration (needed so the Next.js frontend can call this API).
# Allowing every origin is convenient during development; in production set
# CORS_ALLOW_ORIGINS to the frontend domain(s) instead of editing this code.
# NOTE(review): allow_credentials=True combined with a "*" origin lets any
# site make credentialed requests — confirm this is acceptable per deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins_from_env(),
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def root():
    """Report that the API process is up, with a short status message."""
    payload = {
        "status": "ok",
        "message": "CommuniKate API is running",
    }
    return payload
@app.get("/api/health")
async def health():
    """Return a constant health payload: ``{"status": "healthy"}``."""
    return dict(status="healthy")
# Services are instantiated once at import time and shared by every request
# handler in this module.
classifier = MBTIClassifier()
# LLM_MODEL_NAME overrides the model without a code change; the default keeps
# the previously hard-coded "gemma4:latest".
llm_service = LLMService(model_name=os.environ.get("LLM_MODEL_NAME", "gemma4:latest"))
# Value of ``target_mbti_input`` that asks the server to infer the
# counterpart's MBTI instead of using one chosen by the user
# ("automatic analysis (AI)").
_AUTO_ANALYSIS_OPTION = "자동 분석 (AI)"


def _axis_scores_from_probs(probs: Dict[str, float]) -> Dict[str, float]:
    """Fold per-type probabilities into the four MBTI axis scores.

    Each score is the summed probability of every type whose letter on that
    axis is the first pole (I, S, T, P). Keys shorter than four characters
    are skipped. An empty mapping yields all-zero scores.
    """
    scores = {'I/E': 0, 'S/N': 0, 'T/F': 0, 'P/J': 0}
    for mbti, p in probs.items():
        if len(mbti) < 4:
            continue
        if mbti[0] == 'I':
            scores['I/E'] += p
        if mbti[1] == 'S':
            scores['S/N'] += p
        if mbti[2] == 'T':
            scores['T/F'] += p
        if mbti[3] == 'P':
            scores['P/J'] += p
    return scores


@app.post("/api/analyze", response_model=AnalyzeResponse)
async def analyze(request: AnalyzeRequest):
    """Analyze the counterpart's message and draft reply advice.

    The counterpart's MBTI comes from one of three sources:

    * user-specified: ``target_mbti_input`` is anything other than the
      auto-analysis option;
    * auto-analysis with ``context_detail``: the LLM reasons over the extra
      context together with the message;
    * auto-analysis without context: the local classifier predicts a type and
      per-type probabilities, which are folded into axis scores.

    In auto mode a second LLM "expert guide" is appended to the summary. It
    only receives axis percentages when the classifier actually produced them.

    Raises:
        HTTPException: 400 if ``target_text`` is blank.
    """
    if not request.target_text.strip():
        # "Please enter a message."
        raise HTTPException(status_code=400, detail="메시지를 입력해주세요.")

    auto_mode = request.target_mbti_input == _AUTO_ANALYSIS_OPTION
    probs: Dict[str, float] = {}

    if not auto_mode:
        target_mbti = request.target_mbti_input
        # Header: "Specified MBTI", note: "set directly by the user".
        analysis_summary = f"### 👤 지정된 MBTI: {target_mbti}\n**사용자 직접 설정**"
    elif request.context_detail and request.context_detail.strip():
        reasoning_result = await llm_service.analyze_mbti_with_reasoning(
            request.context_detail, request.target_text
        )
        # NOTE(review): this placeholder ("in-depth analysis in progress") is
        # what generate_response receives as the target MBTI below — confirm
        # the prompt tolerates a non-MBTI value here.
        target_mbti = "심도 있는 분석 중"
        # Header: "AI precise situation analysis report".
        analysis_summary = f"### 🧠 AI 정밀 상황 분석 리포트\n{reasoning_result}"
    else:
        analysis = classifier.predict(request.target_text)
        target_mbti = analysis["mbti"]
        confidence = analysis["confidence"]
        probs = analysis["probabilities"]
        # Header: "Message analysis result", then "Confidence".
        analysis_summary = (
            f"### 🎯 메시지 분석 결과: {target_mbti}\n"
            f"**신뢰도: {confidence*100:.1f}%**"
        )

    axis_scores = _axis_scores_from_probs(probs)

    # Reply suggestion. It does not depend on the expert guide, so in auto
    # mode both LLM calls are awaited concurrently.
    advice_call = llm_service.generate_response(
        request.my_mbti, target_mbti, request.situation,
        request.relationship, request.vibe, request.target_text
    )

    if not auto_mode:
        advice = await advice_call
    else:
        # "Situation: ..., Relationship: ..., Vibe: ..."
        guide_context = (
            f"상황: {request.situation}, 관계: {request.relationship}, "
            f"분위기: {request.vibe}"
        )
        # Only pass axis percentages when the classifier produced them; the
        # all-zero defaults would otherwise be presented as real data.
        if probs:
            axis_data = ", ".join(f"{k}: {v*100:.1f}%" for k, v in axis_scores.items())
            # "[Message analysis data]"
            guide_context += f"\n[메시지 분석 데이터] {axis_data}"
        advice, guide = await asyncio.gather(
            advice_call,
            llm_service.analyze_mbti_with_reasoning(guide_context, request.target_text),
        )
        # Section title: "AI expert's personality analysis guide".
        analysis_summary += f"\n\n--- \n#### 🛡️ AI 전문가의 성향 분석 가이드\n{guide}"

    return AnalyzeResponse(
        analysis_summary=analysis_summary,
        probabilities=probs,
        axis_scores=axis_scores,
        advice=advice,
    )
@app.post("/api/ocr", response_model=OCRResponse)
async def ocr(file: UploadFile = File(...)):
    """Decode an uploaded image and extract its text via the LLM service.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image
            (empty body, non-image data, truncated file, or a pixel count
            that trips Pillow's decompression-bomb guard).
    """
    contents = await file.read()
    try:
        image = PIL.Image.open(io.BytesIO(contents))
        # Image.open is lazy; force a decode so corrupt data fails here with
        # a client error instead of deep inside the LLM call.
        image.load()
    except (OSError, PIL.Image.DecompressionBombError) as err:
        # "Unable to read the image file."
        raise HTTPException(status_code=400, detail="이미지 파일을 읽을 수 없습니다.") from err
    text = await llm_service.extract_text_from_image(image)
    return OCRResponse(text=text)
@app.post("/api/chat/start", response_model=ChatStartResponse)
async def chat_start(request: ChatStartRequest):
history = []
coaching_tip = "๋Œ€ํ™”๋ฅผ ์‹œ์ž‘ํ–ˆ์Šต๋‹ˆ๋‹ค. ๋ฉ”์‹œ์ง€๋ฅผ ๋ณด๋‚ด์‹œ๋ฉด AI ์ฝ”์นญ์ด ์‹œ์ž‘๋ฉ๋‹ˆ๋‹ค."
if request.ai_first:
greeting = await llm_service.generate_initial_greeting(
request.target_mbti, request.relationship, request.situation
)
history.append({"role": "assistant", "content": greeting})
coaching_tip = "AI๊ฐ€ ๋จผ์ € ์ธ์‚ฌ๋ฅผ ๊ฑด๋„ธ์Šต๋‹ˆ๋‹ค. ๋Œ€ํ™”๋ฅผ ์ด์–ด๊ฐ€ ๋ณด์„ธ์š”!"
return ChatStartResponse(history=history, coaching_tip=coaching_tip)
@app.post("/api/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Advance the role-play chat by one user turn.

    A blank ``user_input`` short-circuits without calling the LLM and echoes
    the existing history back with a prompt to type something. Otherwise the
    persona reply and the coaching tip, which are independent LLM calls, run
    concurrently; the user turn and the reply are appended to the history.
    """
    history_dicts = [h.dict() for h in request.history]

    if not request.user_input.strip():
        # "Please enter a message."
        return ChatResponse(history=history_dicts, coaching_tip="메시지를 입력해 주세요.")

    response, coaching_tip = await asyncio.gather(
        llm_service.chat_with_persona(
            history_dicts, request.user_input, request.user_mbti,
            request.target_mbti, request.relationship, request.situation,
        ),
        llm_service.get_coaching_tip(
            request.user_input, request.target_mbti, request.relationship
        ),
    )

    new_history = history_dicts + [
        {"role": "user", "content": request.user_input},
        {"role": "assistant", "content": response},
    ]
    return ChatResponse(history=new_history, coaching_tip=coaching_tip)
@app.post("/api/simulate", response_model=SimulateResponse)
async def simulate(request: SimulateRequest):
    """Predict how the counterpart would react to the drafted advice text.

    Forwards both MBTIs, the situation, the relationship and the advice text
    to ``llm_service.simulate_reaction`` and wraps its result.
    """
    simulated = await llm_service.simulate_reaction(
        request.my_mbti,
        request.target_mbti_input,
        request.situation,
        request.relationship,
        request.advice_text,
    )
    return SimulateResponse(reaction=simulated)
if __name__ == "__main__":
    # Serve on every interface; the PORT environment variable overrides 8000.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000")))