import os
import logging

import torch
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# HF Spaces provides HF_TOKEN automatically
HF_TOKEN = os.getenv("HF_TOKEN", None)
if not HF_TOKEN:
    logger.warning("HF_TOKEN not found! Falling back to public models")
# Model configuration (optimized for HF Spaces)
MODEL_CONFIG = {
    "base_model": "google/gemma-1.1-2b-it",
    "lora_model": "programci48/heytak-lora-v1",
    "cache_dir": "/tmp/huggingface",
    "offload_folder": "/tmp/offload",
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
}
def load_models():
    """Model loading optimized for HF Spaces."""
    try:
        # Tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_CONFIG["base_model"],
            token=HF_TOKEN,
            cache_dir=MODEL_CONFIG["cache_dir"],
        )
        # Base model
        base_model = AutoModelForCausalLM.from_pretrained(
            MODEL_CONFIG["base_model"],
            torch_dtype=MODEL_CONFIG["torch_dtype"],
            device_map="auto" if MODEL_CONFIG["device"] == "cuda" else None,
            token=HF_TOKEN,
            cache_dir=MODEL_CONFIG["cache_dir"],
            offload_folder=MODEL_CONFIG["offload_folder"],
        )
        # LoRA adapter
        model = PeftModel.from_pretrained(
            base_model,
            MODEL_CONFIG["lora_model"],
            token=HF_TOKEN,
        )
        model.eval()
        return {"tokenizer": tokenizer, "model": model}
    except Exception as e:
        logger.error(f"Model loading error: {str(e)}")
        raise
# Application startup
app = FastAPI(title="HeyTak AI API")

@app.on_event("startup")
async def startup_event():
    try:
        app.state.models = load_models()
        logger.info("Models loaded successfully!")
    except Exception as e:
        logger.critical(f"Startup error: {str(e)}")
        raise
@app.post("/predict")
async def predict(request: Request):
    try:
        data = await request.json()
        prompt = data.get("inputs", "")
        inputs = app.state.models["tokenizer"](
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        ).to(app.state.models["model"].device)
        with torch.no_grad():
            outputs = app.state.models["model"].generate(
                **inputs,
                max_new_tokens=100,
                do_sample=True,  # sampling must be enabled for temperature/top_p to apply
                temperature=0.7,
                top_p=0.9,
            )
        full_response = app.state.models["tokenizer"].decode(
            outputs[0],
            skip_special_tokens=True,
        ).strip()
        # Keep only the model's output (strip the echoed prompt)
        generated_text = full_response[len(prompt):].strip()
        return {"generated_text": generated_text}
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}")
        # Return an explicit 500 status with the error payload
        return JSONResponse(status_code=500, content={"error": str(e)})
@app.get("/health")
async def health_check():
    return {
        "status": "active",
        "device": MODEL_CONFIG["device"],
        "framework": "FastAPI",
    }
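
# Illustrative local entrypoint (a sketch, not part of the original file):
# HF Spaces serves FastAPI apps via uvicorn on port 7860 by default, so
# running this module directly mirrors that setup for local testing.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request against the /predict endpoint (route path as registered above):
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"inputs": "Merhaba!"}'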