will702 committed on
Commit
b395362
·
verified ·
1 Parent(s): e98af63

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +168 -126
app.py CHANGED
@@ -1,126 +1,168 @@
1
- import json
2
- import os
3
- import re
4
- from contextlib import asynccontextmanager
5
-
6
- from fastapi import FastAPI, HTTPException, Request
7
- from pydantic import BaseModel
8
- from transformers import AutoModelForCausalLM, AutoTokenizer
9
- import torch
10
-
11
# Hugging Face model id used for sentiment classification.
MODEL_NAME = "Qwen/Qwen3.5-0.8B"
# Optional shared secret; when unset, /predict requires no authentication.
API_KEY = os.getenv("API_KEY")

# Populated by lifespan() at startup; both stay None until loading completes.
tokenizer = None
model = None
16
-
17
-
18
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the tokenizer and model once at startup.

    Populates the module-level ``tokenizer`` and ``model`` globals before
    the app begins serving requests. Anything after ``yield`` would run at
    shutdown (nothing to clean up here).
    """
    global tokenizer, model
    print(f"Loading model: {MODEL_NAME}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32,  # CPU requires float32
        device_map="cpu",
    )
    model.eval()  # inference mode: disables dropout etc.
    print("Model loaded.")
    yield
31
-
32
-
33
# The lifespan hook loads the model before the app starts serving requests.
app = FastAPI(title="StockPro Sentiment", lifespan=lifespan)
34
-
35
-
36
class PredictRequest(BaseModel):
    """Request body for /predict: the headlines to classify."""

    # Headlines to analyze; the endpoint rejects empty lists and more than 20.
    texts: list[str]
38
-
39
-
40
# System prompt sent with every request; instructs the model to answer with
# a bare JSON array so parse_response() can extract it reliably.
SYSTEM_PROMPT = """You are a financial sentiment analyzer for Indonesian stock market news.
Analyze each headline and return ONLY a JSON array with no extra text.
Each item must have: "text" (original), "sentiment" ("positive", "negative", or "neutral"), "score" (0.0-1.0 confidence).
Respond only with the JSON array, no markdown, no explanation."""
44
-
45
-
46
- def build_prompt(texts: list[str]) -> str:
47
- headlines = "\n".join(f"{i+1}. {t}" for i, t in enumerate(texts))
48
- return f"Analyze sentiment for these Indonesian stock headlines:\n{headlines}"
49
-
50
-
51
- def parse_response(raw: str, texts: list[str]) -> list[dict]:
52
- # Strip thinking tags if present
53
- raw = re.sub(r"<think>.*?</think>", "", raw, flags=re.DOTALL).strip()
54
- # Extract JSON array
55
- match = re.search(r"\[.*\]", raw, re.DOTALL)
56
- if match:
57
- try:
58
- parsed = json.loads(match.group())
59
- if isinstance(parsed, list) and len(parsed) == len(texts):
60
- return parsed
61
- except json.JSONDecodeError:
62
- pass
63
- # Fallback: neutral for all
64
- return [{"text": t, "sentiment": "neutral", "score": 0.5} for t in texts]
65
-
66
-
67
@app.post("/predict")
async def predict(body: PredictRequest, request: Request):
    """Classify sentiment for up to 20 Indonesian stock-news headlines.

    Auth: when the API_KEY env var is set, callers must send it in the
    X-API-Key header. Returns ``{"results": [{"text", "sentiment",
    "score"}, ...]}``. Raises 400 on bad input, 401 on a bad key, and 503
    while the model is still loading.
    """
    import hmac  # local import: only needed on the authenticated path

    if API_KEY:
        key = request.headers.get("X-API-Key") or ""
        # compare_digest avoids leaking key prefix matches via timing.
        if not hmac.compare_digest(key, API_KEY):
            raise HTTPException(status_code=401, detail="Invalid API key")

    texts = body.texts
    if not texts:
        raise HTTPException(status_code=400, detail="texts must not be empty")
    if len(texts) > 20:
        raise HTTPException(status_code=400, detail="Maximum 20 texts per request")

    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": build_prompt(texts)},
    ]

    text_input = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,  # disable thinking for faster response
    )
    inputs = tokenizer(text_input, return_tensors="pt")

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,  # greedy decoding: deterministic JSON output
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens (skip the echoed prompt).
    generated = outputs[0][inputs["input_ids"].shape[1]:]
    raw = tokenizer.decode(generated, skip_special_tokens=True)

    results = parse_response(raw, texts)

    # Normalize output: valid sentiment label, numeric score clamped to [0, 1].
    normalized = []
    for r in results:
        sentiment = str(r.get("sentiment", "neutral")).lower()
        if sentiment not in ("positive", "negative", "neutral"):
            sentiment = "neutral"
        try:
            score = float(r.get("score", 0.5))
        except (TypeError, ValueError):
            # Model emitted a non-numeric score; fall back to neutral confidence
            # instead of raising an unhandled 500.
            score = 0.5
        normalized.append({
            "text": r.get("text", ""),
            "sentiment": sentiment,
            "score": round(min(max(score, 0.0), 1.0), 4),
        })

    return {"results": normalized}
122
-
123
-
124
@app.get("/health")
def health():
    """Liveness probe: reports whether the model has finished loading."""
    loaded = model is not None
    return {"status": "ok", "model_loaded": loaded}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 6 from fastapi import FastAPI, HTTPException, Request
2
+ 7 from pydantic import BaseModel
3
+ 8 -from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ 9 -import torch
5
+ 8
6
+ 9 MODEL_NAME = "Qwen/Qwen3.5-0.8B"
7
+ 10 API_KEY = os.getenv("API_KEY")
8
+ 11 +HF_TOKEN = os.getenv("HF_TOKEN")
9
+ 12
10
+ 14 -tokenizer = None
11
+ 15 -model = None
12
+ 13 +# Will hold either InferenceClient or local model+tokenizer
13
+ 14 +inference_client = None
14
+ 15 +local_model = None
15
+ 16 +local_tokenizer = None
16
+ 17
17
+ 18
18
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: pick an inference backend at startup.

    With HF_TOKEN set, uses the Hugging Face Inference API via
    ``InferenceClient``; otherwise loads the model locally on CPU and
    applies dynamic INT8 quantization. Exactly one of ``inference_client``
    or ``local_model``/``local_tokenizer`` is populated.
    """
    global inference_client, local_model, local_tokenizer

    if HF_TOKEN:
        # Option 1: HF Inference API (GPU-backed, fast)
        print("HF_TOKEN found — using HF Inference API")
        # Imported lazily so the hosted path needs no local torch/transformers.
        from huggingface_hub import InferenceClient
        inference_client = InferenceClient(model=MODEL_NAME, token=HF_TOKEN)
        print("Inference client ready.")
    else:
        # Option 2: Local model with INT8 quantization (CPU fallback)
        print("No HF_TOKEN loading model locally with INT8 quantization")
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        import torch.quantization

        local_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        # Load in float32 first; quantization converts Linear weights below.
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32,
            device_map="cpu",
        )
        # Apply dynamic INT8 quantization for faster CPU inference
        local_model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8
        )
        local_model.eval()  # inference mode: disables dropout etc.
        print("Local INT8 model ready.")

    yield
59
+ 50
60
+ 51
61
+ ...
62
+ 68
63
+ 69
64
+ 70 def parse_response(raw: str, texts: list[str]) -> list[dict]:
65
+ 52 - # Strip thinking tags if present
66
+ 71 raw = re.sub(r"<think>.*?</think>", "", raw, flags=re.DOTALL).strip()
67
+ 54 - # Extract JSON array
68
+ 72 match = re.search(r"\[.*\]", raw, re.DOTALL)
69
+ 73 if match:
70
+ 74 try:
71
+ ...
72
+ 77 return parsed
73
+ 78 except json.JSONDecodeError:
74
+ 79 pass
75
+ 63 - # Fallback: neutral for all
76
+ 80 return [{"text": t, "sentiment": "neutral", "score": 0.5} for t in texts]
77
+ 81
78
+ 82
79
+ 67 -@app.post("/predict")
80
+ 68 -async def predict(body: PredictRequest, request: Request):
81
+ 69 - if API_KEY:
82
+ 70 - key = request.headers.get("X-API-Key")
83
+ 71 - if key != API_KEY:
84
+ 72 - raise HTTPException(status_code=401, detail="Invalid API key")
85
def run_hf_api(texts: list[str]) -> str:
    """Run the sentiment prompt through the HF Inference API; return raw text."""
    chat = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": build_prompt(texts)},
    ]
    response = inference_client.chat_completion(
        messages=chat,
        max_tokens=512,
        temperature=0.1,
    )
    first_choice = response.choices[0]
    return first_choice.message.content
97
+ 74 - texts = body.texts
98
+ 75 - if not texts:
99
+ 76 - raise HTTPException(status_code=400, detail="texts must not be empty")
100
+ 77 - if len(texts) > 20:
101
+ 78 - raise HTTPException(status_code=400, detail="Maximum 20 texts per request")
102
+ 95
103
+ 80 - if model is None or tokenizer is None:
104
+ 81 - raise HTTPException(status_code=503, detail="Model not loaded yet")
105
+ 82 -
106
def run_local(texts: list[str]) -> str:
    """Generate a completion with the locally loaded INT8-quantized model.

    Builds the chat prompt, runs greedy decoding on CPU, and returns only
    the newly generated text (the echoed prompt tokens are sliced off).
    Assumes lifespan() has populated ``local_model``/``local_tokenizer``.
    """
    # torch is imported lazily: it is only installed/needed on the local path.
    import torch
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": build_prompt(texts)},
    ]
    text_input = local_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        # Skip Qwen "thinking" output for faster, parseable responses.
        enable_thinking=False,
    )
    inputs = local_tokenizer(text_input, return_tensors="pt")
    with torch.no_grad():
        outputs = local_model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,  # greedy decoding: deterministic output
            pad_token_id=local_tokenizer.eos_token_id,
        )
    # Decode only the tokens generated after the prompt.
    generated = outputs[0][inputs["input_ids"].shape[1]:]
    return local_tokenizer.decode(generated, skip_special_tokens=True)
137
+ 118
138
+ 119 +
139
+ 120 +@app.post("/predict")
140
+ 121 +async def predict(body: PredictRequest, request: Request):
141
+ 122 + if API_KEY:
142
+ 123 + key = request.headers.get("X-API-Key")
143
+ 124 + if key != API_KEY:
144
+ 125 + raise HTTPException(status_code=401, detail="Invalid API key")
145
+ 126 +
146
+ 127 + texts = body.texts
147
+ 128 + if not texts:
148
+ 129 + raise HTTPException(status_code=400, detail="texts must not be empty")
149
+ 130 + if len(texts) > 20:
150
+ 131 + raise HTTPException(status_code=400, detail="Maximum 20 texts per request")
151
+ 132 +
152
+ 133 + if inference_client is None and local_model is None:
153
+ 134 + raise HTTPException(status_code=503, detail="Model not loaded yet")
154
+ 135 +
155
+ 136 + raw = run_hf_api(texts) if inference_client else run_local(texts)
156
+ 137 results = parse_response(raw, texts)
157
+ 138
158
+ 109 - # Normalize output format
159
+ 139 normalized = []
160
+ 140 for r in results:
161
+ 141 sentiment = str(r.get("sentiment", "neutral")).lower()
162
+ ...
163
+ 152
164
@app.get("/health")
def health():
    """Liveness probe: reports which inference backend is active."""
    if inference_client:
        mode = "hf_api"
    elif local_model:
        mode = "local_int8"
    else:
        mode = "not_loaded"
    return {"status": "ok", "mode": mode}