|
|
from fastapi import FastAPI, HTTPException, Request |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from pydantic import BaseModel |
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
import torch |
|
|
import os |
|
|
import time |
|
|
|
|
|
|
|
|
# OpenAPI metadata for the service; documentation UIs are pinned to their
# default paths explicitly so the URLs are visible at a glance.
_api_metadata = dict(
    title="IT Job Classifier API",
    description="API для определения IT-вакансий по тексту",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
    openapi_url="/openapi.json",
)

app = FastAPI(**_api_metadata)
|
|
|
|
|
|
|
|
# Allow cross-origin requests from any origin so browser front-ends hosted
# elsewhere can call the API.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is a
# questionable combination — the CORS spec forbids the literal "*" origin when
# credentials are allowed, so the middleware has to echo back each request's
# origin, effectively trusting every site. If cookies or auth headers are ever
# used, restrict allow_origins to the known front-end hosts — TODO confirm
# intended clients.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
class TextRequest(BaseModel):
    """Request body for POST /predict: the raw vacancy text to classify."""

    # Vacancy text; emptiness is checked in the endpoint itself (HTTP 400),
    # not here, so the custom error message is preserved.
    text: str
|
|
|
|
|
class PredictionResponse(BaseModel):
    """Response body for POST /predict."""

    # Human-readable label: "IT-вакансия" or "Не IT-вакансия".
    prediction: str
    # True when the model predicted class 1.
    is_it_job: bool
    # Softmax probability of the predicted class, in [0, 1].
    confidence: float
    # Raw argmax class index from the model's logits.
    class_id: int
|
|
|
|
|
|
|
|
# Hugging Face Hub id of the fine-tuned vacancy-classification model.
MODEL_NAME = "MrAlexGov/BERT-AI-Vacancy"

# Populated by the startup hook below; both stay None until loading finishes,
# and /predict answers 503 while they are None.
model = None

tokenizer = None
|
|
|
|
|
# NOTE(review): `on_event` is deprecated in favor of lifespan handlers in
# newer FastAPI releases — confirm the installed version before migrating.
@app.on_event("startup")
async def load_model():
    """Load the tokenizer and model from the Hugging Face Hub at startup.

    Populates the module-level ``model`` and ``tokenizer`` globals and puts
    the model into eval mode. Any loading error is logged and re-raised so
    the application fails fast instead of serving 503s indefinitely.
    """
    global model, tokenizer

    print(f"🚀 Загружаем модель {MODEL_NAME}...")
    start_time = time.time()

    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

        # Inference-only service: disable dropout / training-mode layers.
        model.eval()

        load_time = time.time() - start_time
        print(f"✅ Модель успешно загружена за {load_time:.2f} секунд")

    except Exception as e:
        print(f"❌ Ошибка загрузки модели: {str(e)}")
        # Bare `raise` preserves the original traceback (unlike `raise e`).
        raise
|
|
|
|
|
@app.get("/")
async def root():
    """Service landing endpoint: human-readable overview of the API."""
    endpoints = {
        "GET /": "Эта документация",
        "GET /health": "Проверка здоровья сервиса",
        "POST /predict": "Классификация текста",
        "GET /docs": "Интерактивная документация (Swagger UI)",
        "GET /redoc": "Альтернативная документация",
    }
    model_info = {
        "name": MODEL_NAME,
        # Falsy until the startup hook assigns the loaded model.
        "status": "loaded" if model else "loading",
    }
    usage = {
        "curl_example": 'curl -X POST https://your-space.hf.space/predict -H "Content-Type: application/json" -d \'{"text": "Ищем Python разработчика"}\''
    }
    return {
        "service": "IT Job Classification API",
        "status": "running",
        "version": "1.0.0",
        "endpoints": endpoints,
        "model": model_info,
        "usage": usage,
    }
|
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe: reports service status and whether the model is loaded."""
    model_ready = model is not None
    return {
        "status": "healthy",
        "timestamp": time.time(),
        "model_loaded": model_ready,
        "service": "IT Job Classifier",
    }
|
|
|
|
|
@app.post("/predict", response_model=PredictionResponse)
def predict(request: TextRequest):
    """Classify ``request.text`` as an IT vacancy or not.

    Declared as a plain ``def`` (not ``async def``) on purpose: tokenization
    and the model forward pass are blocking, CPU-bound calls, and FastAPI
    runs sync endpoints in its threadpool instead of stalling the event loop.

    Raises:
        HTTPException 503: the model has not finished loading yet.
        HTTPException 400: empty or whitespace-only input text.
        HTTPException 500: unexpected tokenization/inference failure.
    """
    if model is None or tokenizer is None:
        raise HTTPException(
            status_code=503,
            detail="Модель еще не загружена. Пожалуйста, подождите."
        )

    if not request.text or not request.text.strip():
        raise HTTPException(
            status_code=400,
            detail="Текст не может быть пустым"
        )

    try:
        inputs = tokenizer(
            request.text,
            return_tensors="pt",
            truncation=True,  # clip long postings to the model's window
            max_length=512,
            padding=True
        )

        # Inference only — no autograd bookkeeping needed.
        with torch.no_grad():
            outputs = model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            prediction_id = torch.argmax(probabilities, dim=-1).item()
            confidence = probabilities[0][prediction_id].item()

        # Class 1 is treated as "IT vacancy" — TODO confirm against the
        # model card's id2label mapping.
        is_it_job = prediction_id == 1
        prediction_label = "IT-вакансия" if is_it_job else "Не IT-вакансия"

        return PredictionResponse(
            prediction=prediction_label,
            is_it_job=is_it_job,
            confidence=confidence,
            class_id=prediction_id
        )

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Ошибка обработки: {str(e)}"
        )
|
|
|
|
|
|
|
|
@app.get("/test")
async def test():
    """Smoke-test endpoint: confirms the process is up and serving requests."""
    payload = {
        "message": "Сервис работает",
        "timestamp": time.time(),
        "model": MODEL_NAME,
    }
    return payload
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Direct-execution entry point for local runs / containers.
    import uvicorn

    # Port is configurable via the PORT environment variable; defaults to 7860.
    port = int(os.getenv("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=port, reload=False)