m97j commited on
Commit
f2a7503
ยท
1 Parent(s): 17ee8b7

Initial commit

Browse files
app.py CHANGED
@@ -4,11 +4,10 @@ from fastapi.middleware.cors import CORSMiddleware
4
  from manager.dialogue_manager import handle_dialogue
5
  from rag.rag_manager import chroma_initialized, load_game_docs_from_disk, add_docs, set_embedder
6
  from contextlib import asynccontextmanager
7
- from models.model_loader import load_emotion_model, load_fallback_model, load_embedder
8
  from schemas import AskReq, AskRes
9
  from pathlib import Path
10
  from config import (
11
- EMOTION_MODEL_NAME, EMOTION_MODEL_DIR,
12
  FALLBACK_MODEL_NAME, FALLBACK_MODEL_DIR,
13
  EMBEDDER_MODEL_NAME, EMBEDDER_MODEL_DIR,
14
  HF_TOKEN, BASE_DIR
@@ -17,10 +16,7 @@ from config import (
17
 
18
  @asynccontextmanager
19
  async def lifespan(app: FastAPI):
20
- # Emotion
21
- emo_tokenizer, emo_model = load_emotion_model(EMOTION_MODEL_NAME, EMOTION_MODEL_DIR, token=HF_TOKEN)
22
- app.state.emotion_tokenizer = emo_tokenizer
23
- app.state.emotion_model = emo_model
24
 
25
  # Fallback
26
  fb_tokenizer, fb_model = load_fallback_model(FALLBACK_MODEL_NAME, FALLBACK_MODEL_DIR, token=HF_TOKEN)
 
4
  from manager.dialogue_manager import handle_dialogue
5
  from rag.rag_manager import chroma_initialized, load_game_docs_from_disk, add_docs, set_embedder
6
  from contextlib import asynccontextmanager
7
+ from models.model_loader import load_fallback_model, load_embedder
8
  from schemas import AskReq, AskRes
9
  from pathlib import Path
10
  from config import (
 
11
  FALLBACK_MODEL_NAME, FALLBACK_MODEL_DIR,
12
  EMBEDDER_MODEL_NAME, EMBEDDER_MODEL_DIR,
13
  HF_TOKEN, BASE_DIR
 
16
 
17
  @asynccontextmanager
18
  async def lifespan(app: FastAPI):
19
+ print("๐Ÿš€ ์„œ๋ฒ„ ์‹œ์ž‘ ์ค‘... ๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
 
 
 
20
 
21
  # Fallback
22
  fb_tokenizer, fb_model = load_fallback_model(FALLBACK_MODEL_NAME, FALLBACK_MODEL_DIR, token=HF_TOKEN)
config.py CHANGED
@@ -14,12 +14,10 @@ HF_TIMEOUT = float(os.getenv("HF_TIMEOUT", "25"))
14
 
15
 
16
  # ๋ชจ๋ธ ์ด๋ฆ„
17
- EMOTION_MODEL_NAME = os.getenv("EMOTION_MODEL_NAME", "tae898/emoberta-base-ko")
18
  FALLBACK_MODEL_NAME = os.getenv("FALLBACK_MODEL_NAME", "skt/ko-gpt-trinity-1.2B-v0.5")
19
  EMBEDDER_MODEL_NAME = os.getenv("EMBEDDER_MODEL_NAME", "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
20
 
21
  # ๋ชจ๋ธ ๋””๋ ‰ํ† ๋ฆฌ
22
- EMOTION_MODEL_DIR = Path(os.getenv("EMOTION_MODEL_DIR", BASE_DIR / "models" / "emotion-classification-model"))
23
  FALLBACK_MODEL_DIR = Path(os.getenv("FALLBACK_MODEL_DIR", BASE_DIR / "models" / "fallback-npc-model"))
24
  EMBEDDER_MODEL_DIR = Path(os.getenv("EMBEDDER_MODEL_DIR", BASE_DIR / "models" / "sentence-embedder"))
25
 
 
14
 
15
 
16
  # ๋ชจ๋ธ ์ด๋ฆ„
 
17
  FALLBACK_MODEL_NAME = os.getenv("FALLBACK_MODEL_NAME", "skt/ko-gpt-trinity-1.2B-v0.5")
18
  EMBEDDER_MODEL_NAME = os.getenv("EMBEDDER_MODEL_NAME", "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
19
 
20
  # ๋ชจ๋ธ ๋””๋ ‰ํ† ๋ฆฌ
 
21
  FALLBACK_MODEL_DIR = Path(os.getenv("FALLBACK_MODEL_DIR", BASE_DIR / "models" / "fallback-npc-model"))
22
  EMBEDDER_MODEL_DIR = Path(os.getenv("EMBEDDER_MODEL_DIR", BASE_DIR / "models" / "sentence-embedder"))
23
 
manager/dialogue_manager.py CHANGED
@@ -3,7 +3,7 @@ from pipeline.preprocess import preprocess_input
3
  from pipeline.generator import generate_response
4
  from pipeline.postprocess import postprocess_fallback, postprocess_main
5
  from models.fallback_model import generate_fallback_response
6
- from .prompt_builder import build_main_prompt, build_fallback_prompt # ์ˆ˜์ •๋œ prompt ๋นŒ๋” ์‚ฌ์šฉ
7
 
8
  async def handle_dialogue(
9
  request: Request,
 
3
  from pipeline.generator import generate_response
4
  from pipeline.postprocess import postprocess_fallback, postprocess_main
5
  from models.fallback_model import generate_fallback_response
6
+ from .prompt_builder import build_main_prompt, build_fallback_prompt
7
 
8
  async def handle_dialogue(
9
  request: Request,
models/model_loader.py CHANGED
@@ -1,24 +1,11 @@
1
  from pathlib import Path
2
  from transformers import (
3
  AutoTokenizer,
4
- AutoModelForSequenceClassification,
5
  AutoModelForCausalLM
6
  )
7
  from sentence_transformers import SentenceTransformer
8
 
9
 
10
- def load_emotion_model(model_name: str, model_dir: Path, token: str = None):
11
- if not (model_dir / "config.json").exists():
12
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_auth_token=token)
13
- model = AutoModelForSequenceClassification.from_pretrained(model_name, trust_remote_code=True, use_auth_token=token)
14
- tokenizer.save_pretrained(model_dir)
15
- model.save_pretrained(model_dir)
16
-
17
- tokenizer = AutoTokenizer.from_pretrained(str(model_dir), trust_remote_code=True, local_files_only=True)
18
- model = AutoModelForSequenceClassification.from_pretrained(str(model_dir), trust_remote_code=True, local_files_only=True)
19
- return tokenizer, model
20
-
21
-
22
  def load_fallback_model(model_name: str, model_dir: Path, token: str = None):
23
  if not (model_dir / "config.json").exists():
24
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_auth_token=token)
 
1
  from pathlib import Path
2
  from transformers import (
3
  AutoTokenizer,
 
4
  AutoModelForCausalLM
5
  )
6
  from sentence_transformers import SentenceTransformer
7
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def load_fallback_model(model_name: str, model_dir: Path, token: str = None):
10
  if not (model_dir / "config.json").exists():
11
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_auth_token=token)
pipeline/preprocess.py CHANGED
@@ -1,7 +1,6 @@
1
  import json, torch
2
  from fastapi import Request
3
  from manager.agent_manager import agent_manager
4
- from models.emotion_model import detect_emotion
5
  from models.fallback_model import generate_fallback_response
6
  from utils.context_parser import ContextParser
7
  from sentence_transformers import util
 
1
  import json, torch
2
  from fastapi import Request
3
  from manager.agent_manager import agent_manager
 
4
  from models.fallback_model import generate_fallback_response
5
  from utils.context_parser import ContextParser
6
  from sentence_transformers import util