ClergeF commited on
Commit
3425760
·
verified ·
1 Parent(s): be20f8f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import numpy as np
3
+ from fastapi import FastAPI
4
+ from pydantic import BaseModel
5
+ from huggingface_hub import hf_hub_download
6
+ from transformers import AutoTokenizer, AutoModel
7
+ import torch
8
+
9
+ # -----------------------------
10
+ # CONFIG
11
+ # -----------------------------
12
+ EMBEDDER_REPO = "ClergeF/MVT-embedder"
13
+ MODEL_REPO = "ClergeF/impact-model"
14
+ MODEL_FILE = "impact.json"
15
+
16
+ # -----------------------------
17
+ # SAFE EMBEDDER LOADER
18
+ # -----------------------------
19
+ def load_safe_embedder(repo_id: str):
20
+ print(f"Loading embedder from {repo_id} (safe mode)...")
21
+ tokenizer = AutoTokenizer.from_pretrained(repo_id)
22
+ model = AutoModel.from_pretrained(repo_id)
23
+
24
+ def embed_fn(texts):
25
+ tokens = tokenizer(
26
+ texts,
27
+ padding=True,
28
+ truncation=True,
29
+ max_length=256,
30
+ return_tensors="pt"
31
+ )
32
+ with torch.no_grad():
33
+ outputs = model(**tokens).last_hidden_state
34
+ embeddings = outputs.mean(dim=1).numpy()
35
+ return embeddings
36
+
37
+ return embed_fn
38
+
39
+ # Load embedder
40
+ embed_fn = load_safe_embedder(EMBEDDER_REPO)
41
+
42
+ # -----------------------------
43
+ # LOAD THE IMPACT MODEL
44
+ # -----------------------------
45
+ print("Loading impact model...")
46
+ path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
47
+ with open(path, "r") as f:
48
+ impact_model = json.load(f)
49
+
50
+ # -----------------------------
51
+ # HELPERS
52
+ # -----------------------------
53
+ def embed(text: str):
54
+ return embed_fn([text])[0]
55
+
56
+ def linear_predict(model_json, vec):
57
+ coef = np.array(model_json["coef"])
58
+ intercept = np.array(model_json["intercept"])
59
+ return float(np.dot(coef, vec) + intercept)
60
+
61
+ # -----------------------------
62
+ # FASTAPI
63
+ # -----------------------------
64
+ app = FastAPI(title="Impact Model API")
65
+
66
+ class Input(BaseModel):
67
+ text: str
68
+
69
+ @app.post("/predict")
70
+ def predict(payload: Input):
71
+ vec = embed(payload.text)
72
+ impact = linear_predict(impact_model, vec)
73
+
74
+ return {
75
+ "input": payload.text,
76
+ "impact_score": impact
77
+ }