hello-ram committed
Commit 777ec21 · verified · 1 Parent(s): 7ed8e50

Update app.py

Files changed (1)
app.py +33 -21
app.py CHANGED
@@ -1,39 +1,50 @@
 from fastapi import FastAPI
 from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForCausalLM
+from peft import PeftModel
 import torch
 
 app = FastAPI()
 
-# ---- Load your HF model repo ----
-MODEL_REPO = "hello-ram/unsolth_gpt.20"
+# 1. Base model
+BASE_MODEL = "gpt2"
+
+# 2. LoRA adapter repo
+LORA_REPO = "hello-ram/unsolth_gpt.20"
+
 
 print("Loading tokenizer...")
-tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
+tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
 
-print("Loading model...")
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_REPO,
+print("Loading base model...")
+base_model = AutoModelForCausalLM.from_pretrained(
+    BASE_MODEL,
     torch_dtype=torch.float16,
+    device_map="auto",
+)
+
+print("Applying LoRA adapter...")
+model = PeftModel.from_pretrained(
+    base_model,
+    LORA_REPO,
     device_map="auto"
 )
 
-# ---------- ROUTES -------------
+model.eval()
+
 
 @app.get("/")
 async def root():
-    return {
-        "message": "🚀 FastAPI MPT Model Running on Hugging Face Spaces",
-        "endpoints": ["/", "/status", "/generate"]
-    }
+    return {"msg": "LoRA model running", "endpoints": ["/status", "/generate"]}
+
 
 @app.get("/status")
 async def status():
     return {
         "status": "ok",
-        "model": MODEL_REPO,
-        "device": str(model.device),
-        "torch_dtype": str(model.dtype)
+        "base_model": BASE_MODEL,
+        "lora_model": LORA_REPO,
+        "device": str(model.device)
     }
 
 
@@ -45,11 +56,12 @@ class InputText(BaseModel):
 async def generate_text(data: InputText):
     inputs = tokenizer(data.text, return_tensors="pt").to(model.device)
 
-    output = model.generate(
-        **inputs,
-        max_new_tokens=200,
-        temperature=0.7
-    )
+    with torch.no_grad():
+        output = model.generate(
+            **inputs,
+            max_new_tokens=200,
+            temperature=0.7
+        )
 
-    generated = tokenizer.decode(output[0], skip_special_tokens=True)
-    return {"response": generated}
+    text = tokenizer.decode(output[0], skip_special_tokens=True)
+    return {"response": text}
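One caveat with the new generate call: transformers only applies temperature when sampling is enabled, and with the default greedy decoding recent versions warn that the flag is unused. A sketch of the same call with sampling switched on, reusing the model, tokenizer, and inputs objects from the diff above:

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=True,           # required for temperature to take effect
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 defines no pad token
        )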
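For reference, a minimal client sketch for exercising the updated endpoints. The base URL below is a placeholder for the deployed Space URL, and the request body assumes InputText declares a single text: str field, consistent with the data.text access in the diff:

    import requests

    BASE_URL = "http://localhost:7860"  # placeholder; replace with the Space URL

    # Confirm the base model and LoRA adapter loaded.
    print(requests.get(f"{BASE_URL}/status", timeout=30).json())

    # Request a completion from the adapted model.
    resp = requests.post(
        f"{BASE_URL}/generate",
        json={"text": "Once upon a time"},
        timeout=120,
    )
    print(resp.json()["response"])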
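If per-request adapter overhead matters, PEFT can also fold the LoRA weights into the base model at startup instead of keeping the adapter hooks; a sketch under the same base_model and LORA_REPO as above:

    from peft import PeftModel

    # merge_and_unload() bakes the LoRA deltas into the base weights and
    # returns a plain transformers model, so inference skips the adapter hooks.
    model = PeftModel.from_pretrained(base_model, LORA_REPO)
    model = model.merge_and_unload()
    model.eval()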