fansa34 committed
Commit ab758b3 · verified · 1 Parent(s): 81ee137

Update app.py

Files changed (1): app.py +53 -14
app.py CHANGED
@@ -1,23 +1,62 @@
-from fastapi import FastAPI
-from pydantic import BaseModel
-from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+from peft import PeftModel
+from fastapi import FastAPI, Request
+from pydantic import BaseModel
 
 app = FastAPI()
 
-# Replace with your actual model repo
-MODEL_NAME = "fansa34/finetunedModel"
+# Configs
+BASE_MODEL = "mistralai/Mistral-7B-v0.1"
+ADAPTER_MODEL = "fansa34/finetunedModel"
+
+# Quantization for 4-bit loading (QLoRA)
+quant_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.float16,
+)
+
+# Load tokenizer and base model
+tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
+if tokenizer.pad_token is None:
+    tokenizer.pad_token = tokenizer.eos_token
 
-# Load model and tokenizer
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+base_model = AutoModelForCausalLM.from_pretrained(
+    BASE_MODEL,
+    device_map="auto",
+    torch_dtype=torch.float16,
+    trust_remote_code=True,
+    quantization_config=quant_config,
+)
 
-class AskRequest(BaseModel):
+# Load LoRA adapter
+model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL)
+model.eval()
+
+# Request schema
+class QueryRequest(BaseModel):
     question: str
+    max_new_tokens: int = 200
+    temperature: float = 0.6
 
 @app.post("/ask")
-def ask(req: AskRequest):
-    inputs = tokenizer(req.question, return_tensors="pt")
-    outputs = model.generate(**inputs, max_new_tokens=100)
-    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return {"answer": answer}
+async def ask(req: QueryRequest):
+    prompt = f"Question: {req.question}\nAnswer:"
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+    with torch.no_grad():
+        output = model.generate(
+            **inputs,
+            max_new_tokens=req.max_new_tokens,
+            temperature=req.temperature,
+            do_sample=True,
+            top_p=0.9,
+            top_k=50,
+            repetition_penalty=1.1,
+            pad_token_id=tokenizer.pad_token_id
+        )
+
+    response = tokenizer.decode(output[0], skip_special_tokens=True).split("Answer:")[-1].strip()
+    return {"question": req.question, "answer": response}