gbakidz committed · Commit bf50024 · verified · 1 Parent(s): d0d5289

Create app.py

Files changed (1)
  1. app.py +72 -0
app.py ADDED
@@ -0,0 +1,72 @@
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
+ import requests
+ from bs4 import BeautifulSoup
+
+ app = FastAPI()
+
+ MODEL_NAME = "microsoft/phi-2"
+
+ print("Loading Phi-2...")
+
+ # Keep CPU usage modest on shared hardware
+ torch.set_num_threads(2)
+
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+ model.to("cpu")
+
+ print("Model loaded!")
+
+ # -------- REQUEST SCHEMA --------
+ class RequestData(BaseModel):
+     prompt: str
+     use_search: bool = False
+
+
+ # -------- WEB SEARCH FUNCTION --------
+ def search_web(query):
+     # Scrape the DuckDuckGo HTML endpoint; let requests handle query encoding
+     url = "https://duckduckgo.com/html/"
+     headers = {"User-Agent": "Mozilla/5.0"}
+
+     response = requests.get(url, params={"q": query}, headers=headers, timeout=10)
+     soup = BeautifulSoup(response.text, "html.parser")
+
+     # Collect the visible titles of the result links
+     results = []
+     for a in soup.select("a.result__a"):
+         results.append(a.get_text())
+
+     return " ".join(results[:5])
+
+
+ # -------- GENERATE FUNCTION --------
+ def generate_text(prompt):
+     # Phi-2's instruction format: "Instruct: ...\nOutput:"
+     formatted = f"Instruct: {prompt}\nOutput:"
+
+     inputs = tokenizer(formatted, return_tensors="pt")
+
+     outputs = model.generate(
+         **inputs,                  # pass attention_mask along with input_ids
+         max_new_tokens=60,
+         do_sample=True,            # temperature only applies when sampling
+         temperature=0.7,
+         pad_token_id=tokenizer.eos_token_id
+     )
+
+     result = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     # Return only the text generated after the "Output:" marker
+     return result.split("Output:")[-1].strip()
+
+
+ # -------- API ENDPOINT --------
+ @app.post("/generate")
+ def generate(data: RequestData):
+     prompt = data.prompt
+
+     # Optionally augment the prompt with scraped web results
+     if data.use_search:
+         web_data = search_web(prompt)
+         prompt = f"{prompt}\n\nWeb Info: {web_data}"
+
+     response = generate_text(prompt)
+
+     return {
+         "response": response
+     }
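
A minimal client sketch for exercising the /generate endpoint once the app is being served. The base URL and port are assumptions for local testing (e.g. launching with `uvicorn app:app --port 8000`); they are not part of the committed file.

# Hypothetical client for the /generate endpoint; URL/port are assumed, not from the commit.
import requests

payload = {"prompt": "Explain what FastAPI is in one sentence.", "use_search": False}
resp = requests.post("http://localhost:8000/generate", json=payload, timeout=120)
print(resp.json()["response"])

Setting "use_search": True would route the prompt through search_web() first, appending the scraped DuckDuckGo result titles as "Web Info" before generation.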