# heytak-ai-api/app.py
import logging
import os

import torch
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# HF Spaces provides HF_TOKEN automatically
HF_TOKEN = os.getenv("HF_TOKEN", None)
if not HF_TOKEN:
    logger.warning("HF_TOKEN not found! Falling back to public models")

# Model configuration (tuned for HF Spaces)
MODEL_CONFIG = {
    "base_model": "google/gemma-1.1-2b-it",
    "lora_model": "programci48/heytak-lora-v1",
    "cache_dir": "/tmp/huggingface",
    "offload_folder": "/tmp/offload",
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
}


def load_models():
    """Model loading optimized for HF Spaces."""
    try:
        # Tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_CONFIG["base_model"],
            token=HF_TOKEN,
            cache_dir=MODEL_CONFIG["cache_dir"],
        )
        # Base model
        base_model = AutoModelForCausalLM.from_pretrained(
            MODEL_CONFIG["base_model"],
            torch_dtype=MODEL_CONFIG["torch_dtype"],
            device_map="auto" if MODEL_CONFIG["device"] == "cuda" else None,
            token=HF_TOKEN,
            cache_dir=MODEL_CONFIG["cache_dir"],
            offload_folder=MODEL_CONFIG["offload_folder"],
        )
        # LoRA adapter
        model = PeftModel.from_pretrained(
            base_model,
            MODEL_CONFIG["lora_model"],
            token=HF_TOKEN,
        )
        model.eval()
        return {"tokenizer": tokenizer, "model": model}
    except Exception as e:
        logger.error(f"Model loading error: {e}")
        raise


# Application startup
app = FastAPI(title="HeyTak AI API")


# Note: newer FastAPI versions prefer lifespan handlers over on_event,
# but on_event still works here
@app.on_event("startup")
async def startup_event():
    try:
        app.state.models = load_models()
        logger.info("Models loaded successfully!")
    except Exception as e:
        logger.critical(f"Startup error: {e}")
        raise


@app.post("/predict")
async def predict(request: Request):
    try:
        data = await request.json()
        prompt = data.get("inputs", "")
        tokenizer = app.state.models["tokenizer"]
        model = app.state.models["model"]
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        ).to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=True,  # required for temperature/top_p to take effect
                temperature=0.7,
                top_p=0.9,
            )
        # Decode only the newly generated tokens; slicing the decoded string
        # by len(prompt) is fragile once special tokens are stripped
        prompt_length = inputs["input_ids"].shape[1]
        generated_text = tokenizer.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True,
        ).strip()
        return {"generated_text": generated_text}
    except Exception as e:
        logger.error(f"Prediction error: {e}")
        # Returning a (dict, int) tuple does not set the HTTP status code in
        # FastAPI; use JSONResponse to return a real 500
        return JSONResponse(status_code=500, content={"error": str(e)})


@app.get("/")
async def health_check():
    return {
        "status": "active",
        "device": MODEL_CONFIG["device"],
        "framework": "FastAPI",
    }
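

# Local-run sketch: HF Spaces normally launches the server itself (e.g. via
# the Space's Dockerfile or runtime), so this guard is only for local testing
# and assumes uvicorn is installed.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)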