zerovic committed on
Commit
b28fff6
·
verified ·
1 Parent(s): 691e1ca

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ from fastapi import FastAPI
4
+ from pydantic import BaseModel
5
+
6
# Application object that the route decorators further down attach to.
app = FastAPI()

# Hugging Face Hub identifier of the checkpoint served by this app.
MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"

# Tokenizer and weights are loaded once, at import time.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Half precision is requested only when a CUDA device is visible; otherwise
# the weights are loaded in full precision.
# NOTE(review): nothing in this block moves the model onto a CUDA device, so
# float16 weights appear to stay on the default device -- confirm intended.
if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
    )
else:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32,
    )
16
+
17
class RequestData(BaseModel):
    # JSON body accepted by the POST /generate endpoint.
    # `inputs` carries the raw user prompt that is forwarded to the model.
    inputs: str
19
+
20
+
21
def generate_text(prompt: str) -> str:
    """Generate an assistant reply to *prompt* using the module-level model.

    The prompt is wrapped in Phi-3 chat markers, a continuation is sampled,
    and only the newly generated text is returned.

    Args:
        prompt: Raw user message.

    Returns:
        The decoded completion with surrounding whitespace stripped. The
        prompt itself is not part of the returned text.
    """
    # Phi-3 instruct turns are terminated with <|end|>; the user turn is
    # closed before the assistant turn is opened.
    formatted_prompt = f"<|user|>\n{prompt}<|end|>\n<|assistant|>\n"

    # Place the encoded prompt on the same device as the model so generation
    # works regardless of where the weights were loaded.
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The returned sequence starts with the prompt tokens. Drop them by token
    # count: splitting the decoded text on "<|assistant|>" is unreliable
    # because skip_special_tokens=True typically strips that marker, which
    # would leave the prompt echoed in the response.
    new_tokens = output[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
46
+
47
+
48
@app.post("/generate")
async def generate(request: RequestData):
    """Handle POST /generate: run the model on the request's prompt.

    Args:
        request: Parsed request body; ``request.inputs`` is the user prompt.

    Returns:
        A dict of the form ``{"data": [text]}`` where ``text`` is the
        generated reply.
    """
    # Imported locally so this handler is self-contained; fastapi is already
    # a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    # Model inference is blocking and long-running. Calling it directly inside
    # a coroutine would stall the event loop (and every other request) until
    # generation finishes, so it is offloaded to the worker threadpool.
    text = await run_in_threadpool(generate_text, request.inputs)

    return {"data": [text]}