Spaces:
Sleeping
Sleeping
| import os | |
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from typing import List, Dict, Optional | |
| import uvicorn | |
| import PIL.Image | |
| import io | |
| from services.classifier import MBTIClassifier | |
| from services.llm_service import LLMService | |
| from schemas import ( | |
| AnalyzeRequest, AnalyzeResponse, | |
| OCRResponse, | |
| ChatStartRequest, ChatStartResponse, | |
| ChatRequest, ChatResponse, | |
| SimulateRequest, SimulateResponse | |
| ) | |
app = FastAPI(title="CommuniKate API")

# CORS setup (needed so the Next.js frontend can call this API).
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable, e.g. "https://app.example.com,https://admin.example.com".
# When the variable is unset or blank, every origin is allowed ("*"). That is
# the development-friendly default this module always used. In production,
# set it to the frontend domain(s) instead.
# NOTE(review): combining a "*" origin with allow_credentials=True is not
# permitted by the CORS spec. Starlette's CORSMiddleware compensates by echoing
# the request's Origin header, which effectively lets any site make
# credentialed requests. Confirm this is acceptable outside development.
_cors_allow_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
async def root():
    """Return a static payload confirming that the API process is running.

    NOTE(review): no route decorator is visible for this handler here.
    Confirm how it is registered on ``app``.
    """
    payload = dict(status="ok", message="CommuniKate API is running")
    return payload
async def health():
    """Return the fixed health-check payload ``{"status": "healthy"}``.

    NOTE(review): no route decorator is visible for this handler here.
    Confirm how it is registered on ``app``.
    """
    status_payload = dict(status="healthy")
    return status_payload
# Initialize the services once, at import time (singleton pattern). The
# request handlers in this module share these module-level instances.
classifier = MBTIClassifier()
# The LLM model can be overridden with the LLM_MODEL_NAME environment variable.
# When it is unset or blank, the previously hard-coded "gemma4:latest" is used.
llm_service = LLMService(
    model_name=os.environ.get("LLM_MODEL_NAME") or "gemma4:latest"
)
async def analyze(request: AnalyzeRequest):
    """Infer the counterpart's MBTI (or take it as given) and suggest a reply.

    ``request.target_mbti_input`` selects the mode:

    * Auto-analysis sentinel with a non-blank ``context_detail``:
      ``llm_service.analyze_mbti_with_reasoning`` reasons over the context and
      the message. No probabilities or axis scores are produced.
    * Auto-analysis sentinel without ``context_detail``:
      ``classifier.predict`` supplies the MBTI, its confidence and per-type
      probabilities. Per-axis scores are summed from those probabilities.
    * Anything else: the value is used verbatim as the counterpart's MBTI.

    Reply advice is always generated. In auto-analysis mode, a second
    reasoning pass is appended to the summary. It is fed the situation,
    relationship and vibe, plus the axis scores when the classifier ran. The
    advice call and this second pass are independent and run concurrently.

    Args:
        request: The analysis request; ``target_text`` must not be blank.

    Returns:
        AnalyzeResponse with the markdown summary, the classifier
        probabilities (empty unless the classifier ran), the axis scores and
        the reply advice.

    Raises:
        HTTPException: 400 when ``target_text`` is empty or whitespace only.

    NOTE(review): no route decorator is visible for this handler here.
    Confirm how it is registered on ``app``.
    """
    import asyncio

    if not request.target_text.strip():
        raise HTTPException(status_code=400, detail="๋ฉ์์ง๋ฅผ ์ ๋ ฅํด์ฃผ์ธ์.")

    # NOTE(review): the Korean literals in this function look mis-encoded
    # (UTF-8 text rendered through a Thai code page). This one is compared
    # with client input, so it must match the frontend's option label exactly,
    # or auto-analysis can never trigger. Confirm.
    auto_mode = request.target_mbti_input == "์๋ ๋ถ์ (AI)"

    probs = {}
    axis_scores = {'I/E': 0, 'S/N': 0, 'T/F': 0, 'P/J': 0}

    if not auto_mode:
        # The user specified the counterpart's MBTI explicitly.
        target_mbti = request.target_mbti_input
        analysis_summary = f"### ๐ค ์ง์ ๋ MBTI: {target_mbti}\n**์ฌ์ฉ์ ์ง์ ์ค์ **"
    elif request.context_detail and request.context_detail.strip():
        # Free-form context is available: let the LLM reason over it directly.
        reasoning_result = await llm_service.analyze_mbti_with_reasoning(
            request.context_detail, request.target_text
        )
        target_mbti = "์ฌ๋ ์๋ ๋ถ์ ์ค"
        analysis_summary = f"### ๐ง AI ์ ๋ฐ ์ํฉ ๋ถ์ ๋ฆฌํฌํธ\n{reasoning_result}"
    else:
        # NOTE(review): classifier.predict is called synchronously inside this
        # coroutine. If it is CPU-heavy, it blocks the event loop while it runs.
        analysis = classifier.predict(request.target_text)
        target_mbti = analysis["mbti"]
        confidence = analysis["confidence"]
        probs = analysis["probabilities"]
        analysis_summary = f"### ๐ฏ ๋ฉ์์ง ๋ถ์ ๊ฒฐ๊ณผ: {target_mbti}\n**์ ๋ขฐ๋: {confidence*100:.1f}%**"
        # Axis scores: for each position, sum the probability mass of the
        # types whose letter there is I, S, T or P respectively. This assumes
        # the probability keys are 4-letter MBTI codes.
        for mbti, p in probs.items():
            if mbti[0] == 'I':
                axis_scores['I/E'] += p
            if mbti[1] == 'S':
                axis_scores['S/N'] += p
            if mbti[2] == 'T':
                axis_scores['T/F'] += p
            if mbti[3] == 'P':
                axis_scores['P/J'] += p

    # Prompt for the second, data-backed reasoning pass (auto-analysis only).
    guide_context = None
    if auto_mode:
        guide_context = f"์ํฉ: {request.situation}, ๊ด๊ณ: {request.relationship}, ๋ถ์๊ธฐ: {request.vibe}"
        # Cite message-analysis data only when the classifier produced it.
        # In the context_detail path every score is still 0, and presenting
        # those zeros as analysis data would mislead the model.
        if probs:
            axis_data = ", ".join(f"{k}: {v*100:.1f}%" for k, v in axis_scores.items())
            guide_context += f"\n[๋ฉ์์ง ๋ถ์ ๋ฐ์ดํฐ] {axis_data}"

    # Reply advice is always requested. The guide pass does not depend on it,
    # so both LLM calls are awaited together (same approach as the chat
    # handler).
    calls = [
        llm_service.generate_response(
            request.my_mbti, target_mbti, request.situation,
            request.relationship, request.vibe, request.target_text
        )
    ]
    if guide_context is not None:
        calls.append(
            llm_service.analyze_mbti_with_reasoning(guide_context, request.target_text)
        )
    results = await asyncio.gather(*calls)
    advice = results[0]
    if guide_context is not None:
        analysis_summary += f"\n\n--- \n#### ๐ก๏ธ AI ์ ๋ฌธ๊ฐ์ ์ฑํฅ ๋ถ์ ๊ฐ์ด๋\n{results[1]}"

    return AnalyzeResponse(
        analysis_summary=analysis_summary,
        probabilities=probs,
        axis_scores=axis_scores,
        advice=advice
    )
async def ocr(file: UploadFile = File(...)):
    """Extract text from an uploaded image via ``llm_service``.

    Args:
        file: The uploaded file; it must contain image data Pillow can decode.

    Returns:
        OCRResponse holding the text returned by
        ``llm_service.extract_text_from_image``.

    Raises:
        HTTPException: 400 when the upload is empty or cannot be decoded as an
            image. Previously, Pillow's exception escaped as an unhandled 500.

    NOTE(review): no route decorator is visible for this handler here.
    Confirm how it is registered on ``app``.
    """
    contents = await file.read()
    if not contents:
        raise HTTPException(status_code=400, detail="업로드된 파일이 비어 있습니다.")
    try:
        image = PIL.Image.open(io.BytesIO(contents))
        # Image.open only reads the header. load() forces a full decode, so
        # corrupt or truncated files are rejected here rather than failing
        # later inside the LLM call.
        image.load()
    except (OSError, PIL.Image.DecompressionBombError) as exc:
        # UnidentifiedImageError (not an image) and truncated-data errors are
        # OSError subclasses. DecompressionBombError guards oversized inputs.
        raise HTTPException(
            status_code=400, detail="이미지 파일을 읽을 수 없습니다. 올바른 이미지를 업로드해주세요."
        ) from exc
    text = await llm_service.extract_text_from_image(image)
    return OCRResponse(text=text)
async def chat_start(request: ChatStartRequest):
    """Open a role-play chat session.

    If ``request.ai_first`` is falsy, the session starts with an empty history.
    Otherwise ``llm_service.generate_initial_greeting`` produces an opening
    line, which becomes the single assistant entry in the history.

    Returns:
        ChatStartResponse with the initial history and a coaching tip.

    NOTE(review): no route decorator is visible for this handler here.
    Confirm how it is registered on ``app``.
    """
    if not request.ai_first:
        return ChatStartResponse(
            history=[],
            coaching_tip="๋ํ๋ฅผ ์์ํ์ต๋๋ค. ๋ฉ์์ง๋ฅผ ๋ณด๋ด์๋ฉด AI ์ฝ์นญ์ด ์์๋ฉ๋๋ค.",
        )

    opening_line = await llm_service.generate_initial_greeting(
        request.target_mbti, request.relationship, request.situation
    )
    return ChatStartResponse(
        history=[{"role": "assistant", "content": opening_line}],
        coaching_tip="AI๊ฐ ๋จผ์ ์ธ์ฌ๋ฅผ ๊ฑด๋ธ์ต๋๋ค. ๋ํ๋ฅผ ์ด์ด๊ฐ ๋ณด์ธ์!",
    )
async def chat(request: ChatRequest):
    """Advance a role-play chat by one user turn.

    A blank ``user_input`` short-circuits: the existing history is echoed back
    with a prompt to enter a message. Otherwise the persona reply and the
    coaching tip are requested from ``llm_service`` concurrently.

    Returns:
        ChatResponse whose history is the previous turns, then the user turn,
        then the persona's reply. It also carries a coaching tip for the
        user's message.

    NOTE(review): no route decorator is visible for this handler here.
    Confirm how it is registered on ``app``.
    """
    if not request.user_input.strip():
        unchanged_turns = [turn.dict() for turn in request.history]
        return ChatResponse(history=unchanged_turns, coaching_tip="๋ฉ์์ง๋ฅผ ์ ๋ ฅํด ์ฃผ์ธ์.")

    import asyncio

    past_turns = [turn.dict() for turn in request.history]
    persona_reply, tip = await asyncio.gather(
        llm_service.chat_with_persona(
            past_turns,
            request.user_input,
            request.user_mbti,
            request.target_mbti,
            request.relationship,
            request.situation,
        ),
        llm_service.get_coaching_tip(
            request.user_input, request.target_mbti, request.relationship
        ),
    )

    user_turn = {"role": "user", "content": request.user_input}
    reply_turn = {"role": "assistant", "content": persona_reply}
    return ChatResponse(history=[*past_turns, user_turn, reply_turn], coaching_tip=tip)
async def simulate(request: SimulateRequest):
    """Ask the LLM to simulate the counterpart's reaction to ``advice_text``.

    The request fields are forwarded unchanged to
    ``llm_service.simulate_reaction``; its result is wrapped in a
    SimulateResponse.

    NOTE(review): no route decorator is visible for this handler here.
    Confirm how it is registered on ``app``.
    """
    predicted_reaction = await llm_service.simulate_reaction(
        request.my_mbti,
        request.target_mbti_input,
        request.situation,
        request.relationship,
        request.advice_text,
    )
    return SimulateResponse(reaction=predicted_reaction)
if __name__ == "__main__":
    # Run the app directly with uvicorn, listening on all interfaces.
    # The port comes from the PORT environment variable (8000 when unset).
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 8000)),
    )