MrAlexGov's picture
Update app.py
897c861 verified
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import os
import time
# Инициализация приложения
app = FastAPI(
title="IT Job Classifier API",
description="API для определения IT-вакансий по тексту",
version="1.0.0",
docs_url="/docs", # Явно включаем документацию
redoc_url="/redoc",
openapi_url="/openapi.json"
)
# Добавляем CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Модели Pydantic
class TextRequest(BaseModel):
text: str
class PredictionResponse(BaseModel):
prediction: str
is_it_job: bool
confidence: float
class_id: int
# Глобальные переменные для модели
MODEL_NAME = "MrAlexGov/BERT-AI-Vacancy"
model = None
tokenizer = None
@app.on_event("startup")
async def load_model():
"""Загрузка модели при старте приложения"""
global model, tokenizer
print(f"🚀 Загружаем модель {MODEL_NAME}...")
start_time = time.time()
try:
# Загружаем токенизатор и модель
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
# Переводим модель в режим оценки
model.eval()
load_time = time.time() - start_time
print(f"✅ Модель успешно загружена за {load_time:.2f} секунд")
except Exception as e:
print(f"❌ Ошибка загрузки модели: {str(e)}")
raise e
@app.get("/")
async def root():
"""Корневой эндпоинт с документацией"""
return {
"service": "IT Job Classification API",
"status": "running",
"version": "1.0.0",
"endpoints": {
"GET /": "Эта документация",
"GET /health": "Проверка здоровья сервиса",
"POST /predict": "Классификация текста",
"GET /docs": "Интерактивная документация (Swagger UI)",
"GET /redoc": "Альтернативная документация"
},
"model": {
"name": MODEL_NAME,
"status": "loaded" if model else "loading"
},
"usage": {
"curl_example": 'curl -X POST https://your-space.hf.space/predict -H "Content-Type: application/json" -d \'{"text": "Ищем Python разработчика"}\''
}
}
@app.get("/health")
async def health():
"""Проверка здоровья сервиса"""
return {
"status": "healthy",
"timestamp": time.time(),
"model_loaded": model is not None,
"service": "IT Job Classifier"
}
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: TextRequest):
"""Классифицирует текст на IT/не IT вакансию"""
# Проверяем, загружена ли модель
if model is None or tokenizer is None:
raise HTTPException(
status_code=503,
detail="Модель еще не загружена. Пожалуйста, подождите."
)
# Проверяем входные данные
if not request.text or not request.text.strip():
raise HTTPException(
status_code=400,
detail="Текст не может быть пустым"
)
try:
# Токенизация
inputs = tokenizer(
request.text,
return_tensors="pt",
truncation=True,
max_length=512,
padding=True
)
# Предсказание
with torch.no_grad():
outputs = model(**inputs)
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
prediction_id = torch.argmax(probabilities, dim=-1).item()
confidence = probabilities[0][prediction_id].item()
# Интерпретация результата
# Предполагаем, что label_1 - это IT-вакансия
is_it_job = prediction_id == 1
prediction_label = "IT-вакансия" if is_it_job else "Не IT-вакансия"
return PredictionResponse(
prediction=prediction_label,
is_it_job=is_it_job,
confidence=confidence,
class_id=prediction_id
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Ошибка обработки: {str(e)}"
)
# Эндпоинт для тестирования
@app.get("/test")
async def test():
"""Тестовый эндпоинт"""
return {
"message": "Сервис работает",
"timestamp": time.time(),
"model": MODEL_NAME
}
# Если запускаем напрямую (не через Docker на Spaces)
if __name__ == "__main__":
import uvicorn
# Получаем порт из переменных окружения (для Hugging Face Spaces)
port = int(os.getenv("PORT", 7860))
uvicorn.run(
"app:app",
host="0.0.0.0",
port=port,
reload=False # Отключаем reload в production
)