Zynara committed on
Commit
f3d004f
·
verified ·
1 Parent(s): 654d7b3
Files changed (1) hide show
  1. app.py +61 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Request
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ import torch
6
+
7
# === CONFIG ===
MODEL_NAME = "microsoft/phi-2"  # Replace with phi-4 if available

# === INIT ===
# Load tokenizer + model once at import time (FastAPI workers reuse them).
# Pick the device first, then the dtype: fp16 halves GPU memory, while
# fp32 is kept for CPU, where fp16 generation is slow/poorly supported.
app = FastAPI()
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
# BUG FIX: fp16 was selected when CUDA is available, but the model was never
# moved off the CPU — inputs.to(model.device) then targeted the CPU too.
model.to(device)
model.eval()  # inference mode: disables dropout etc.

# === CORS (for browser clients) ===
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # SECURITY: wide open — replace with your frontend domain in production
    allow_methods=["*"],
    allow_headers=["*"],
)
23
+
24
# === Request Schema ===
class ChatRequest(BaseModel):
    """Body of POST /chat (loosely OpenAI-chat-shaped); validated by pydantic."""
    # Model name sent by the client; accepted but ignored — the server always
    # serves the MODEL_NAME loaded at startup.
    model: str = "phi-4"
    # Chat history: presumably a list of {"role": ..., "content": ...} dicts
    # with "user"/"assistant" roles (see format_prompt) — TODO confirm callers.
    messages: list
28
+
29
# === Helper Function ===
def format_prompt(messages):
    """Render an OpenAI-style message list as a plain-text "Billy" transcript.

    Unknown roles (e.g. "system") are silently skipped; the transcript ends
    with a dangling "Billy:" cue for the model to complete.
    """
    speaker_for = {"user": "User", "assistant": "Billy"}
    lines = ["You are Billy, a helpful and friendly assistant.\n"]
    for turn in messages:
        speaker = speaker_for.get(turn["role"])
        if speaker is not None:
            lines.append(f"{speaker}: {turn['content']}")
    lines.append("Billy:")
    return "\n".join(lines)
41
+
42
# === Chat Endpoint ===
@app.post("/chat")
async def chat(req: ChatRequest):
    """Generate Billy's next reply for the supplied chat history.

    Returns an OpenAI-ish payload: {"message": {"content": <reply text>}}.
    """
    prompt = format_prompt(req.messages)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]  # tokens belonging to the prompt

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=True,   # sampled (not greedy) decoding for varied replies
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,  # phi models define no pad token
        )

    # BUG FIX: decode only the newly generated tokens. The old code decoded
    # the whole sequence and took split("Billy:")[-1], which truncated the
    # reply whenever the model itself emitted the literal string "Billy:".
    response = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

    # The model has no stop string and will often continue the dialogue with
    # an imagined user turn — keep only Billy's own first turn.
    response = response.split("\nUser:")[0].strip()

    return {"message": {"content": response}}