sahil239 committed
Commit 4283167 · verified · 1 Parent(s): 702b01a

Create app.py

Files changed (1)
  1. app.py +63 -0
app.py ADDED
@@ -0,0 +1,63 @@
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from peft import PeftModel
+ import torch
+
+ app = FastAPI()
+
+ # === Model ===
+ MODEL_REPO = "sahil239/falcon-lora-chatbot"  # replace with your HF repo
+ BASE_MODEL = "tiiuae/falcon-rw-1b"
+
+ # === Load tokenizer ===
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
+ tokenizer.pad_token = tokenizer.eos_token  # Falcon ships no pad token; reuse EOS to avoid padding errors
+
+ # === Load base model and attach the LoRA adapter (see merging note below) ===
+ base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, trust_remote_code=True)
+ model = PeftModel.from_pretrained(base_model, MODEL_REPO)
+ model.eval()
+
+ # === Move to GPU if available ===
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+
+ # === Request schema ===
+ class PromptRequest(BaseModel):
+     prompt: str
+     max_new_tokens: int = 200
+     temperature: float = 0.7
+     top_p: float = 0.95
+
+ @app.get("/")
+ def health_check():
+     return {"status": "running"}
+
+ @app.post("/generate")
+ async def generate_text(req: PromptRequest):
+     inputs = tokenizer(
+         req.prompt,
+         return_tensors="pt",
+         padding=True,
+         truncation=True,
+         max_length=200,  # cap the prompt length
+     )
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+
+     with torch.no_grad():
+         outputs = model.generate(
+             input_ids=inputs["input_ids"],
+             attention_mask=inputs["attention_mask"],
+             max_new_tokens=req.max_new_tokens,
+             temperature=req.temperature,
+             top_p=req.top_p,
+             do_sample=True,
+             pad_token_id=tokenizer.eos_token_id,
+             eos_token_id=tokenizer.eos_token_id,  # stop once the model emits end-of-sequence
+             repetition_penalty=1.2,  # penalize repeated phrases
+             no_repeat_ngram_size=3,
+         )
+
+     new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]  # drop the echoed prompt tokens
+     return {"response": tokenizer.decode(new_tokens, skip_special_tokens=True).strip()}
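Note: PeftModel.from_pretrained attaches the LoRA adapter but does not merge it, so every forward pass routes through the adapter weights. For a serving-only app like this one, the adapter can optionally be folded into the base weights with PEFT's merge_and_unload(). A minimal sketch of that variant (not part of this commit; it reuses the repo names above):

from peft import PeftModel
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-rw-1b", trust_remote_code=True)
model = PeftModel.from_pretrained(base_model, "sahil239/falcon-lora-chatbot")
model = model.merge_and_unload()  # returns a plain transformers model with the LoRA weights baked in
model.eval()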
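Usage: with the app served by uvicorn (e.g. uvicorn app:app), the /generate endpoint takes a JSON body matching PromptRequest, and GET / serves as a health check. A quick client sketch, assuming the default local port 8000 (adjust the URL for your deployment):

import requests

resp = requests.post(
    "http://localhost:8000/generate",  # assumed uvicorn default; adjust host/port
    json={
        "prompt": "Explain LoRA fine-tuning in one paragraph.",
        "max_new_tokens": 120,
        "temperature": 0.7,
        "top_p": 0.95,
    },
)
resp.raise_for_status()
print(resp.json()["response"])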