programci48 committed
Commit 2b748ab · verified · 1 Parent(s): bc25ec1

Update app.py

Files changed (1): app.py +50 -38
app.py CHANGED
@@ -4,74 +4,82 @@ from fastapi import FastAPI, Request
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from peft import PeftModel
 from typing import Dict, Any
+import logging

-# Hugging Face token
+# Logging setup
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Environment variables and configuration
 HF_TOKEN = os.getenv("HF_TOKEN")
 if not HF_TOKEN:
+    logger.error("HF_TOKEN environment variable not set!")
     raise ValueError("HF_TOKEN environment variable not set!")

-# Set the cache directory (one with write permission)
-os.environ["HF_HOME"] = "/tmp/huggingface"
-os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
-
-# Model IDs
-BASE_MODEL_ID = "google/gemma-1.1-2b-it"
-LORA_MODEL_ID = "programci48/heytak-lora-v1"
+# Model configuration
+MODEL_CONFIG = {
+    "base_model": "google/gemma-1.1-2b-it",
+    "lora_model": "programci48/heytak-lora-v1",
+    "cache_dir": "/tmp/huggingface",
+    "device": "cuda" if torch.cuda.is_available() else "cpu",
+    "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32
+}

-# Load models with error handling and optimizations
 def load_models() -> Dict[str, Any]:
+    """Load the models."""
     try:
-        # Load tokenizer (using the token directly, no login step)
+        logger.info("Loading tokenizer...")
         tokenizer = AutoTokenizer.from_pretrained(
-            BASE_MODEL_ID,
+            MODEL_CONFIG["base_model"],
             token=HF_TOKEN,
-            cache_dir="/tmp/huggingface"
+            cache_dir=MODEL_CONFIG["cache_dir"]
         )

-        # Load base model with memory optimization
+        logger.info(f"Loading base model ({MODEL_CONFIG['device']})...")
         base_model = AutoModelForCausalLM.from_pretrained(
-            BASE_MODEL_ID,
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+            MODEL_CONFIG["base_model"],
+            torch_dtype=MODEL_CONFIG["torch_dtype"],
             device_map="auto",
             token=HF_TOKEN,
             low_cpu_mem_usage=True,
-            cache_dir="/tmp/huggingface"
+            cache_dir=MODEL_CONFIG["cache_dir"]
         )

-        # Load LoRA adapter
+        logger.info("Loading LoRA adapter...")
         model = PeftModel.from_pretrained(
             base_model,
-            LORA_MODEL_ID,
+            MODEL_CONFIG["lora_model"],
             token=HF_TOKEN
         )
         model.eval()
-
-        # Move to CPU if no GPU available
-        if not torch.cuda.is_available():
+
+        if MODEL_CONFIG["device"] == "cpu":
             model = model.to("cpu")
-            print("Model moved to CPU")
+            torch.cuda.empty_cache()

-        return {
-            "tokenizer": tokenizer,
-            "model": model
-        }
+        logger.info("Models loaded successfully!")
+        return {"tokenizer": tokenizer, "model": model}

     except Exception as e:
-        raise RuntimeError(f"Model loading failed: {str(e)}")
-
-# Initialize models
-models = load_models()
-
-# FastAPI app
-app = FastAPI(title="Gemma-LoRA API")
-
+        logger.error(f"Model loading error: {str(e)}")
+        raise
+
+# Application startup
+try:
+    models = load_models()
+    app = FastAPI(title="Gemma-LoRA API", version="1.0")
+except Exception as e:
+    logger.critical(f"Application failed to start: {str(e)}")
+    raise
+
+# API endpoints
 @app.post("/run/predict")
 async def predict(request: Request):
     try:
         data = await request.json()
         prompt = data["data"][0]
+        logger.info(f"Incoming request: {prompt[:50]}...")

         inputs = models["tokenizer"](
             prompt,
             return_tensors="pt",
@@ -79,7 +87,6 @@ async def predict(request: Request):
             max_length=512
         ).to(models["model"].device)

-        # Generate response
         with torch.no_grad():
             outputs = models["model"].generate(
                 **inputs,
@@ -90,17 +97,22 @@ async def predict(request: Request):
                 repetition_penalty=1.1
             )

-        # Decode and clean response
         response = models["tokenizer"].decode(
             outputs[0],
             skip_special_tokens=True
         ).strip()

+        logger.info(f"Generated response: {response[:50]}...")
         return {"data": [response]}

     except Exception as e:
+        logger.error(f"Processing error: {str(e)}")
         return {"error": str(e)}, 500

 @app.get("/health")
 async def health_check():
-    return {"status": "healthy"}
+    return {
+        "status": "healthy",
+        "device": MODEL_CONFIG["device"],
+        "torch_dtype": str(MODEL_CONFIG["torch_dtype"])
+    }
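
For quick testing, here is a minimal client sketch against the endpoints defined in this revision of app.py. It follows the payload shape visible in the code ({"data": [prompt]} in, {"data": [response]} out); the base URL is an assumption (7860 is the usual Hugging Face Spaces port) and may differ per deployment.

import requests

BASE_URL = "http://localhost:7860"  # assumption: default Spaces port, adjust for your deployment

# Health check: after this commit it also reports device and dtype.
print(requests.get(f"{BASE_URL}/health").json())

# /run/predict expects {"data": [prompt]} and returns {"data": [response]}.
payload = {"data": ["Briefly explain what a LoRA adapter is."]}
resp = requests.post(f"{BASE_URL}/run/predict", json=payload)
print(resp.json()["data"][0])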