SMILE1205 committed on
Commit
ee646ee
·
verified ·
1 Parent(s): 1918e01

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +60 -0
main.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ from langchain_huggingface import HuggingFaceEmbeddings
5
+ from langchain_chroma import Chroma
6
+ from huggingface_hub import snapshot_download
7
+ import os
8
+ import shutil
9
+
10
# Application instance. The CORS middleware is registered with wildcards
# for origins, methods and headers.
app = FastAPI()

_cors_settings = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
17
+
18
# Directory on local disk from which the Chroma vector store is loaded.
DB_LOCAL_PATH = "./chroma_db"

# When the store is not on disk yet, download the dataset snapshot from the
# Hugging Face Hub and copy its ``chroma_db`` folder into place.
# HF_REPO_ID is required; a missing value raises KeyError at import time.
if not os.path.exists(DB_LOCAL_PATH):
    print("πŸ”„ HuggingFaceμ—μ„œ DB λ‹€μš΄λ‘œλ“œ 쀑...")
    snapshot_download(
        repo_id=os.environ["HF_REPO_ID"],
        repo_type="dataset",
        # The token is only needed for private or gated repos, so an unset
        # HF_TOKEN must not abort startup; None lets the hub client decide.
        token=os.environ.get("HF_TOKEN"),
        local_dir="./hf_data",
    )
    # Copy into a staging directory and rename it afterwards. An interrupted
    # copy then never leaves a half-filled DB_LOCAL_PATH that the existence
    # check above would accept as complete on the next start.
    _staging_path = DB_LOCAL_PATH + ".partial"
    shutil.rmtree(_staging_path, ignore_errors=True)
    shutil.copytree("./hf_data/chroma_db", _staging_path)
    os.rename(_staging_path, DB_LOCAL_PATH)
    print("βœ… DB λ‹€μš΄λ‘œλ“œ μ™„λ£Œ")
30
+
31
# Build the embedding function, open the Chroma store persisted under
# DB_LOCAL_PATH, and report how many entries the collection holds.
print("πŸ”„ μž„λ² λ”© λͺ¨λΈ λ‘œλ”© 쀑...")
embeddings = HuggingFaceEmbeddings(
    model_name="jhgan/ko-sroberta-multitask",
)
db = Chroma(
    embedding_function=embeddings,
    persist_directory=DB_LOCAL_PATH,
)
print(f"βœ… DB λ‘œλ“œ μ™„λ£Œ β€” 청크 수: {db._collection.count()}")
35
+
36
# Request body accepted by the POST /retrieve endpoint.
class QueryRequest(BaseModel):
    # Text used for the similarity search.
    query: str
    # Number of documents to return; 5 when the client omits the field.
    k: int = 5
39
+
40
# GET / : status message plus the current entry count of the collection.
@app.get("/")
def root():
    chunk_total = db._collection.count()
    return dict(status="EduMap RAG API 정상 μž‘λ™ 쀑", chunks=chunk_total)
43
+
44
# GET /health : fixed "ok" status plus the collection's entry count.
@app.get("/health")
def health():
    chunk_total = db._collection.count()
    return dict(status="ok", chunks=chunk_total)
47
+
48
@app.post("/retrieve")
def retrieve(req: QueryRequest):
    """Return the stored chunks most similar to ``req.query``.

    Args:
        req: Search text and the number of results wanted (``k``).

    Returns:
        A dict with a ``documents`` list. Each entry holds the chunk
        ``text`` and its ``source`` and ``region`` metadata; a placeholder
        string is used when a metadata key is absent.

    Raises:
        HTTPException: 422 when ``req.k`` is smaller than 1.
    """
    # Imported locally so this handler does not rely on a change to the
    # module-level import list.
    from fastapi import HTTPException

    # A non-positive result count is a client error: answer with 422 rather
    # than letting it surface as an unhandled error from the store query.
    if req.k < 1:
        raise HTTPException(
            status_code=422,
            detail="k must be a positive integer",
        )

    matches = db.similarity_search(req.query, k=req.k)
    return {
        "documents": [
            {
                "text": match.page_content,
                "source": match.metadata.get("source", "μ•Œ 수 μ—†μŒ"),
                "region": match.metadata.get("region", "μ•Œ 수 μ—†μŒ"),
            }
            for match in matches
        ]
    }