opex792 commited on
Commit
b7271ae
·
verified ·
1 Parent(s): 7d62568

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +21 -20
main.py CHANGED
@@ -7,7 +7,8 @@ from pydub import AudioSegment
7
  from fastapi import FastAPI, UploadFile, File, Form, Depends, HTTPException, status
8
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
9
  from fastapi.concurrency import run_in_threadpool
10
- from pydantic import BaseModel, Field, BaseSettings
 
11
  from typing import List, Dict, Optional, Tuple, Any
12
  import tempfile
13
  import uvicorn
@@ -21,7 +22,7 @@ class Settings(BaseSettings):
21
  log_level: str = "INFO"
22
 
23
  class Config:
24
- env_file = ".env"
25
  env_file_encoding = 'utf-8'
26
 
27
  settings = Settings()
@@ -160,16 +161,16 @@ class ModalityType(str):
160
  IMU = "imu"
161
 
162
  class EmbeddingItem(BaseModel):
163
- id: str = Field(..., description="Identifier of the item (e.g., filename or text content)")
164
- embedding: List[float] = Field(..., description="The computed embedding vector")
165
 
166
  class EmbeddingPayload(BaseModel):
167
- vision: Optional[List[EmbeddingItem]] = Field(None, description="List of vision embeddings")
168
- audio: Optional[List[EmbeddingItem]] = Field(None, description="List of audio embeddings")
169
- text: Optional[List[EmbeddingItem]] = Field(None, description="List of text embeddings")
170
- depth: Optional[List[EmbeddingItem]] = Field(None, description="List of depth embeddings (future support)")
171
- thermal: Optional[List[EmbeddingItem]] = Field(None, description="List of thermal embeddings (future support)")
172
- imu: Optional[List[EmbeddingItem]] = Field(None, description="List of IMU embeddings (future support)")
173
 
174
  class EmbeddingResponse(BaseModel):
175
  embeddings: EmbeddingPayload
@@ -180,15 +181,15 @@ class SimilarityMatch(BaseModel):
180
  item_b_id: str
181
  modality_a: ModalityType
182
  modality_b: ModalityType
183
- score: float = Field(..., ge=0.0, le=1.0001)
184
 
185
  class SimilarityRequest(BaseModel):
186
- embeddings_payload: EmbeddingPayload = Field(..., description="Payload containing embeddings from the /compute_embeddings endpoint")
187
- threshold: float = Field(0.5, ge=0.0, le=1.0, description="Minimum similarity score to include in results")
188
- top_k: Optional[int] = Field(None, gt=0, description="Maximum number of matches to return per modality pair comparison. If None, all matches above threshold are returned.")
189
- normalize_scores: bool = Field(True, description="Whether to normalize embeddings before computing cosine similarity (recommended)")
190
- compare_within_modalities: bool = Field(True, description="Compare items within the same modality (e.g., image1 vs image2)")
191
- compare_across_modalities: bool = Field(True, description="Compare items across different modalities (e.g., image1 vs text1)")
192
 
193
  class SimilarityResponse(BaseModel):
194
  matches: List[SimilarityMatch]
@@ -197,9 +198,9 @@ class SimilarityResponse(BaseModel):
197
 
198
  @app.post("/compute_embeddings", response_model=EmbeddingResponse, dependencies=[Depends(verify_token)])
199
  async def generate_embeddings_endpoint(
200
- texts: Optional[List[str]] = Form(None, description="List of text strings to embed."),
201
- images: Optional[List[UploadFile]] = File(default=None, description="List of image files."),
202
- audio_files: Optional[List[UploadFile]] = File(default=None, description="List of audio files (MP3, WAV, etc.).")
203
  ):
204
  if embedding_manager is None:
205
  raise HTTPException(status_code=503, detail="Embedding manager not initialized.")
 
7
  from fastapi import FastAPI, UploadFile, File, Form, Depends, HTTPException, status
8
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
9
  from fastapi.concurrency import run_in_threadpool
10
+ from pydantic import BaseModel, Field # Убрали BaseSettings отсюда
11
+ from pydantic_settings import BaseSettings # <--- ИЗМЕНЕННЫЙ ИМПОРТ
12
  from typing import List, Dict, Optional, Tuple, Any
13
  import tempfile
14
  import uvicorn
 
22
  log_level: str = "INFO"
23
 
24
  class Config:
25
+ env_file = ".env"
26
  env_file_encoding = 'utf-8'
27
 
28
  settings = Settings()
 
161
  IMU = "imu"
162
 
163
  class EmbeddingItem(BaseModel):
164
+ id: str
165
+ embedding: List[float]
166
 
167
  class EmbeddingPayload(BaseModel):
168
+ vision: Optional[List[EmbeddingItem]] = None
169
+ audio: Optional[List[EmbeddingItem]] = None
170
+ text: Optional[List[EmbeddingItem]] = None
171
+ depth: Optional[List[EmbeddingItem]] = None
172
+ thermal: Optional[List[EmbeddingItem]] = None
173
+ imu: Optional[List[EmbeddingItem]] = None
174
 
175
  class EmbeddingResponse(BaseModel):
176
  embeddings: EmbeddingPayload
 
181
  item_b_id: str
182
  modality_a: ModalityType
183
  modality_b: ModalityType
184
+ score: float
185
 
186
  class SimilarityRequest(BaseModel):
187
+ embeddings_payload: EmbeddingPayload
188
+ threshold: float = 0.5
189
+ top_k: Optional[int] = None
190
+ normalize_scores: bool = True
191
+ compare_within_modalities: bool = True
192
+ compare_across_modalities: bool = True
193
 
194
  class SimilarityResponse(BaseModel):
195
  matches: List[SimilarityMatch]
 
198
 
199
  @app.post("/compute_embeddings", response_model=EmbeddingResponse, dependencies=[Depends(verify_token)])
200
  async def generate_embeddings_endpoint(
201
+ texts: Optional[List[str]] = Form(None),
202
+ images: Optional[List[UploadFile]] = File(default=None),
203
+ audio_files: Optional[List[UploadFile]] = File(default=None)
204
  ):
205
  if embedding_manager is None:
206
  raise HTTPException(status_code=503, detail="Embedding manager not initialized.")