# ChatbotRAG / main.py
# Author: minhvtt — commit 6c982a7 ("Upload 6 files")
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import Optional, List
from PIL import Image
import io
import numpy as np
from embedding_service import JinaClipEmbeddingService
from qdrant_service import QdrantVectorService
# FastAPI application object shared by every route below.
app = FastAPI(
    title="Event Social Media Embeddings API",
    description="API để embeddings và search text + images từ events & social media với Jina CLIP v2 + Qdrant",
    version="1.0.0",
)

# Build the embedding model and the vector-store client once, at import time,
# so all request handlers reuse the same instances.
print("Initializing services...")
embedding_service = JinaClipEmbeddingService(model_path="jinaai/jina-clip-v2")
qdrant_service = QdrantVectorService(
    # No URL / API key is passed here; per the original note they are read
    # from environment variables inside QdrantVectorService (verify there).
    collection_name="event_social_media",
    # The collection's vector size must match the model's embedding dimension.
    vector_size=embedding_service.get_embedding_dimension(),
)
print("✓ Services initialized successfully")
# Pydantic models
class SearchRequest(BaseModel):
    """Search parameters mirroring the non-file form fields of ``/search``.

    Not referenced by the form-based endpoints in this module.
    """

    text: Optional[str] = None               # query text, optional
    limit: int = 10                          # maximum number of results
    score_threshold: Optional[float] = None  # optional minimum confidence score
    text_weight: float = 0.5                 # weight of the text query
    image_weight: float = 0.5                # weight of the image query
class SearchResponse(BaseModel):
    """One hit returned by the ``/search`` family of endpoints."""

    id: str            # document ID as reported by the Qdrant service
    confidence: float  # score of the hit as reported by the Qdrant service
    metadata: dict     # stored payload (presumably the dict written by /index)
class IndexResponse(BaseModel):
    """Response body of ``POST /index``."""

    success: bool  # True when indexing completed
    id: str        # original document ID echoed back by the Qdrant service
    message: str   # human-readable status message
@app.get("/")
async def root():
    """Health check: report that the API is running and which backends it uses."""
    status_info = dict(
        status="running",
        service="Event Social Media Embeddings API",
        embedding_model="Jina CLIP v2",
        vector_db="Qdrant",
        language_support="Vietnamese + 88 other languages",
    )
    return status_info
def _l2_normalize(vector):
    """Return *vector* scaled to unit L2 norm along its last axis.

    Works for both 1-D ``(dim,)`` and 2-D ``(n, dim)`` embeddings. Rows whose
    norm is zero are returned unchanged instead of turning into NaN.
    """
    vector = np.asarray(vector)
    norms = np.linalg.norm(vector, axis=-1, keepdims=True)
    # Guard against 0/0 for a degenerate all-zero embedding.
    safe_norms = np.where(norms == 0, 1.0, norms)
    return vector / safe_norms


@app.post("/index", response_model=IndexResponse)
async def index_data(
    id: str = Form(...),
    text: str = Form(...),
    image: Optional[UploadFile] = File(None),
):
    """
    Index one document (text and/or image) into the vector database.

    Form fields:
    - id: Document ID (event ID, post ID, etc.)
    - text: Text content (Vietnamese supported); blank text is ignored
    - image: Optional image file

    When both text and image are given, their embeddings are averaged. The
    resulting vector is L2-normalized before it is stored in Qdrant.

    Returns:
    - success: True on success
    - id: original document ID echoed back by the Qdrant service
    - message: Status message

    Raises:
    - HTTPException 400 if neither non-blank text nor an image is provided
    - HTTPException 500 for any other failure (encoding, image decoding, Qdrant)
    """
    try:
        text_embedding = None
        image_embedding = None

        # Encode the text (Vietnamese supported) unless it is blank.
        if text and text.strip():
            text_embedding = embedding_service.encode_text(text)

        # Encode the image, if one was uploaded.
        if image:
            image_bytes = await image.read()
            pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
            image_embedding = embedding_service.encode_image(pil_image)

        # Fuse the available modalities into one vector.
        if text_embedding is not None and image_embedding is not None:
            combined_embedding = np.mean([text_embedding, image_embedding], axis=0)
        elif text_embedding is not None:
            combined_embedding = text_embedding
        elif image_embedding is not None:
            combined_embedding = image_embedding
        else:
            raise HTTPException(status_code=400, detail="Phải cung cấp ít nhất text hoặc image")

        # Normalize along the last axis so both (dim,) and (n, dim) shapes work.
        combined_embedding = _l2_normalize(combined_embedding)

        metadata = {
            "text": text,
            "has_image": image is not None,
            "image_filename": image.filename if image else None,
        }
        result = qdrant_service.index_data(
            doc_id=id,
            embedding=combined_embedding,
            metadata=metadata,
        )
        return IndexResponse(
            success=True,
            # Original (caller-supplied) ID, e.g. a MongoDB ObjectId.
            id=result["original_id"],
            message=f"Đã index thành công document {result['original_id']} (Qdrant UUID: {result['qdrant_id']})",
        )
    except HTTPException:
        # Let deliberate HTTP errors (the 400 above) through unchanged instead
        # of re-wrapping them as a 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Lỗi khi index: {str(e)}") from e
@app.post("/search", response_model=List[SearchResponse])
async def search(
    text: Optional[str] = Form(None),
    image: Optional[UploadFile] = File(None),
    limit: int = Form(10),
    score_threshold: Optional[float] = Form(None),
    text_weight: float = Form(0.5),
    image_weight: float = Form(0.5),
):
    """
    Search for similar documents using text and/or an image.

    Form fields:
    - text: Query text (Vietnamese supported); blank text is ignored
    - image: Optional query image
    - limit: Maximum number of results (default: 10)
    - score_threshold: Minimum confidence score (0-1)
    - text_weight: Weight of the text query (default: 0.5)
    - image_weight: Weight of the image query (default: 0.5)

    Returns:
    - List of results, each with id, confidence and metadata

    Raises:
    - HTTPException 400 if neither non-blank text nor an image is provided
    - HTTPException 500 for any other failure
    """
    try:
        text_embedding = None
        image_embedding = None

        # Encode the text query unless it is blank.
        if text and text.strip():
            text_embedding = embedding_service.encode_text(text)

        # Encode the image query, if one was uploaded.
        if image:
            image_bytes = await image.read()
            pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
            image_embedding = embedding_service.encode_image(pil_image)

        if text_embedding is None and image_embedding is None:
            raise HTTPException(status_code=400, detail="Phải cung cấp ít nhất text hoặc image để search")

        results = qdrant_service.hybrid_search(
            text_embedding=text_embedding,
            image_embedding=image_embedding,
            text_weight=text_weight,
            image_weight=image_weight,
            limit=limit,
            score_threshold=score_threshold,
            ef=256,  # high-accuracy search
        )

        return [
            SearchResponse(
                id=result["id"],
                confidence=result["confidence"],
                metadata=result["metadata"],
            )
            for result in results
        ]
    except HTTPException:
        # Keep the intentional 400 above instead of converting it into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Lỗi khi search: {str(e)}") from e
@app.post("/search/text", response_model=List[SearchResponse])
async def search_by_text(
    text: str = Form(...),
    limit: int = Form(10),
    score_threshold: Optional[float] = Form(None),
):
    """
    Search using text only (Vietnamese supported).

    Form fields:
    - text: Query text; must not be blank
    - limit: Maximum number of results
    - score_threshold: Minimum confidence score

    Returns:
    - List of results, each with id, confidence and metadata

    Raises:
    - HTTPException 400 if the text is blank (consistent with /search)
    - HTTPException 500 for any other failure
    """
    try:
        # Reject whitespace-only queries instead of searching with a
        # meaningless embedding, matching the validation done by /search.
        if not text.strip():
            raise HTTPException(status_code=400, detail="Phải cung cấp text để search")

        text_embedding = embedding_service.encode_text(text)

        results = qdrant_service.search(
            query_embedding=text_embedding,
            limit=limit,
            score_threshold=score_threshold,
            ef=256,
        )
        return [
            SearchResponse(
                id=result["id"],
                confidence=result["confidence"],
                metadata=result["metadata"],
            )
            for result in results
        ]
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Lỗi khi search: {str(e)}") from e
@app.post("/search/image", response_model=List[SearchResponse])
async def search_by_image(
    image: UploadFile = File(...),
    limit: int = Form(10),
    score_threshold: Optional[float] = Form(None),
):
    """
    Search using an image only.

    Form fields:
    - image: Query image (required)
    - limit: Maximum number of results
    - score_threshold: Minimum confidence score

    Returns:
    - List of results, each with id, confidence and metadata
    """
    try:
        raw_bytes = await image.read()
        rgb_image = Image.open(io.BytesIO(raw_bytes)).convert('RGB')
        query_vector = embedding_service.encode_image(rgb_image)

        hits = qdrant_service.search(
            query_embedding=query_vector,
            limit=limit,
            score_threshold=score_threshold,
            ef=256,
        )

        response = [
            SearchResponse(id=hit["id"], confidence=hit["confidence"], metadata=hit["metadata"])
            for hit in hits
        ]
        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Lỗi khi search: {str(e)}")
@app.delete("/delete/{doc_id}")
async def delete_document(doc_id: str):
    """
    Delete a document by ID (MongoDB ObjectId or UUID).

    Args:
    - doc_id: ID of the document to delete

    Returns:
    - Dict with a success flag and a status message
    """
    try:
        qdrant_service.delete_by_id(doc_id)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Lỗi khi xóa: {str(e)}")
    return {"success": True, "message": f"Đã xóa document {doc_id}"}
@app.get("/document/{doc_id}")
async def get_document(doc_id: str):
    """
    Fetch a single document by ID (MongoDB ObjectId or UUID).

    Args:
    - doc_id: Document ID

    Returns:
    - Dict with a success flag and the stored document under "data"

    Raises:
    - HTTPException 404 if no document is found
    - HTTPException 500 if the lookup itself fails
    """
    try:
        doc = qdrant_service.get_by_id(doc_id)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Lỗi khi get document: {str(e)}")

    # Guard clause: a falsy result means nothing was found.
    if not doc:
        raise HTTPException(status_code=404, detail=f"Không tìm thấy document {doc_id}")

    return {"success": True, "data": doc}
@app.get("/stats")
async def get_stats():
    """
    Return statistics about the Qdrant collection.

    Returns:
    - Whatever qdrant_service.get_collection_info() reports
    """
    try:
        return qdrant_service.get_collection_info()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Lỗi khi lấy stats: {str(e)}")
if __name__ == "__main__":
    # Run a development server directly when executed as a script.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000, "log_level": "info"}
    uvicorn.run(app, **server_options)