programci48 committed
Commit 620e2b8 · verified · 1 Parent(s): b511652

Update app.py

Files changed (1)
  1. app.py +90 -26
app.py CHANGED
@@ -3,40 +3,104 @@ import torch
 from fastapi import FastAPI, Request
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from peft import PeftModel
+from huggingface_hub import login
+from typing import Dict, Any
+from fastapi.responses import JSONResponse
 
-# Hugging Face token (required for gated models)
-hf_token = os.getenv("HF_TOKEN")
-print("HF_TOKEN:", hf_token)  # the output shows up in the logs
+# Hugging Face token
+HF_TOKEN = os.getenv("HF_TOKEN")
+if not HF_TOKEN:
+    raise ValueError("HF_TOKEN environment variable not set!")
+
+# Log in to the Hugging Face Hub
+login(token=HF_TOKEN)
 
-# Model IDs
-base_model_id = "google/gemma-1.1-2b-it"
-lora_model_id = "programci48/heytak-lora-v1"
+# Model IDs
+BASE_MODEL_ID = "google/gemma-1.1-2b-it"
+LORA_MODEL_ID = "programci48/heytak-lora-v1"
 
-# Load the tokenizer and model
-tokenizer = AutoTokenizer.from_pretrained(base_model_id, token=hf_token)
+# Load models with error handling and optimizations
+def load_models() -> Dict[str, Any]:
+    try:
+        # Load tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(
+            BASE_MODEL_ID,
+            token=HF_TOKEN
+        )
 
-base_model = AutoModelForCausalLM.from_pretrained(
-    base_model_id,
-    torch_dtype=torch.float32,
-    device_map=None,  # no GPU mapping in the Hugging Face CPU environment
-    token=hf_token
-)
+        # Load base model with memory optimization
+        base_model = AutoModelForCausalLM.from_pretrained(
+            BASE_MODEL_ID,
+            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+            device_map="auto",
+            token=HF_TOKEN,
+            low_cpu_mem_usage=True,
+            offload_folder="offload"  # for CPU offloading if needed
+        )
 
-model = PeftModel.from_pretrained(base_model, lora_model_id, token=hf_token)
-model.eval()
+        # Load LoRA adapter
+        model = PeftModel.from_pretrained(
+            base_model,
+            LORA_MODEL_ID,
+            token=HF_TOKEN
+        )
+        model.eval()
 
-# FastAPI app
-app = FastAPI()
+        # Move to CPU if no GPU is available
+        if not torch.cuda.is_available():
+            model = model.to("cpu")
+            print("Model moved to CPU")
+
+        return {
+            "tokenizer": tokenizer,
+            "model": model
+        }
+
+    except Exception as e:
+        raise RuntimeError(f"Model loading failed: {str(e)}")
+
+# Initialize models
+models = load_models()
+
+# FastAPI app
+app = FastAPI(title="Gemma-LoRA API")
 
 @app.post("/run/predict")
 async def predict(request: Request):
-    data = await request.json()
-    prompt = data["data"][0]
+    try:
+        data = await request.json()
+        prompt = data["data"][0]
 
-    # Generate a response with the model
-    inputs = tokenizer(prompt, return_tensors="pt")
-    with torch.no_grad():
-        outputs = model.generate(**inputs, max_new_tokens=100)
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        # Tokenize with truncation
+        inputs = models["tokenizer"](
+            prompt,
+            return_tensors="pt",
+            truncation=True,
+            max_length=512
+        ).to(models["model"].device)
+
+        # Generate response
+        with torch.no_grad():
+            outputs = models["model"].generate(
+                **inputs,
+                max_new_tokens=100,
+                do_sample=True,
+                temperature=0.7,
+                top_p=0.9,
+                repetition_penalty=1.1
+            )
+
+        # Decode and clean response
+        response = models["tokenizer"].decode(
+            outputs[0],
+            skip_special_tokens=True
+        ).strip()
 
-    return {"data": [response]}
+        return {"data": [response]}
+
+    except Exception as e:
+        return JSONResponse(status_code=500, content={"error": str(e)})
+
+@app.get("/health")
+async def health_check():
+    return {"status": "healthy"}