import os
import torch
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import logging

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# HF Spaces provides HF_TOKEN automatically
HF_TOKEN = os.getenv("HF_TOKEN", None)
if not HF_TOKEN:
    logger.warning("HF_TOKEN not found! Falling back to public models")

# Model configuration (optimized for HF Spaces)
MODEL_CONFIG = {
    "base_model": "google/gemma-1.1-2b-it",
    "lora_model": "programci48/heytak-lora-v1",
    "cache_dir": "/tmp/huggingface",
    "offload_folder": "/tmp/offload",
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32
}

def load_models():
    """HF Spaces için optimize edilmiş model yükleme"""
    try:
        # Tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_CONFIG["base_model"],
            token=HF_TOKEN,
            cache_dir=MODEL_CONFIG["cache_dir"]
        )

        # Base model
        base_model = AutoModelForCausalLM.from_pretrained(
            MODEL_CONFIG["base_model"],
            torch_dtype=MODEL_CONFIG["torch_dtype"],
            device_map="auto" if MODEL_CONFIG["device"] == "cuda" else None,
            token=HF_TOKEN,
            cache_dir=MODEL_CONFIG["cache_dir"],
            offload_folder=MODEL_CONFIG["offload_folder"]
        )

        # LoRA adapter
        model = PeftModel.from_pretrained(
            base_model,
            MODEL_CONFIG["lora_model"],
            token=HF_TOKEN
        )
        model.eval()

        return {"tokenizer": tokenizer, "model": model}

    except Exception as e:
        logger.error(f"Model yükleme hatası: {str(e)}")
        raise

# Application setup
app = FastAPI(title="HeyTak AI API")

@app.on_event("startup")
async def startup_event():
    try:
        app.state.models = load_models()
        logger.info("Modeller başarıyla yüklendi!")
    except Exception as e:
        logger.critical(f"Başlatma hatası: {str(e)}")
        raise

@app.post("/predict")
async def predict(request: Request):
    try:
        data = await request.json()
        prompt = data.get("inputs", "")
        
        inputs = app.state.models["tokenizer"](
            prompt, 
            return_tensors="pt",
            truncation=True,
            max_length=512
        ).to(app.state.models["model"].device)

        with torch.no_grad():
            outputs = app.state.models["model"].generate(
                **inputs,
                max_new_tokens=100,
                do_sample=True,  # sampling must be enabled, otherwise temperature/top_p are ignored
                temperature=0.7,
                top_p=0.9
            )

        # Decode only the newly generated tokens; slicing the decoded string by
        # len(prompt) is fragile because tokenization round-trips can alter it
        prompt_length = inputs["input_ids"].shape[1]
        generated_text = app.state.models["tokenizer"].decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True
        ).strip()

        return {"generated_text": generated_text}

    except Exception as e:
        logger.error(f"Prediction error: {str(e)}")
        # Returning a (dict, status) tuple is a Flask idiom; FastAPI needs an explicit response
        return JSONResponse(status_code=500, content={"error": str(e)})

@app.get("/")
async def health_check():
    return {
        "status": "active",
        "device": MODEL_CONFIG["device"],
        "framework": "FastAPI"
    }
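
# Local development entry point: a minimal sketch for running and testing the API
# outside of HF Spaces (Spaces launches the app with its own runner). Assumptions:
# uvicorn is installed alongside fastapi, and port 7860 (the HF Spaces default)
# is free. An example request against the running server:
#
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"inputs": "Merhaba!"}'
#
# which should return a JSON body of the form {"generated_text": "..."}.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)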