will702 committed on
Commit
84f612c
·
verified ·
1 Parent(s): 9f01dff

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -168
app.py CHANGED
@@ -1,168 +1,156 @@
1
- 6 from fastapi import FastAPI, HTTPException, Request
2
- 7 from pydantic import BaseModel
3
- 8 -from transformers import AutoModelForCausalLM, AutoTokenizer
4
- 9 -import torch
5
- 8
6
- 9 MODEL_NAME = "Qwen/Qwen3.5-0.8B"
7
- 10 API_KEY = os.getenv("API_KEY")
8
- 11 +HF_TOKEN = os.getenv("HF_TOKEN")
9
- 12
10
- 14 -tokenizer = None
11
- 15 -model = None
12
- 13 +# Will hold either InferenceClient or local model+tokenizer
13
- 14 +inference_client = None
14
- 15 +local_model = None
15
- 16 +local_tokenizer = None
16
- 17
17
- 18
18
- 19 @asynccontextmanager
19
- 20 async def lifespan(app: FastAPI):
20
- 20 - global tokenizer, model
21
- 21 - print(f"Loading model: {MODEL_NAME}")
22
- 22 - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
23
- 23 - model = AutoModelForCausalLM.from_pretrained(
24
- 24 - MODEL_NAME,
25
- 25 - torch_dtype=torch.float32, # CPU requires float32
26
- 26 - device_map="cpu",
27
- 27 - )
28
- 28 - model.eval()
29
- 29 - print("Model loaded.")
30
- 21 + global inference_client, local_model, local_tokenizer
31
- 22 +
32
- 23 + if HF_TOKEN:
33
- 24 + # Option 1: HF Inference API (GPU-backed, fast)
34
- 25 + print("HF_TOKEN found — using HF Inference API")
35
- 26 + from huggingface_hub import InferenceClient
36
- 27 + inference_client = InferenceClient(model=MODEL_NAME, token=HF_TOKEN)
37
- 28 + print("Inference client ready.")
38
- 29 + else:
39
- 30 + # Option 2: Local model with INT8 quantization (CPU fallback)
40
- 31 + print("No HF_TOKEN — loading model locally with INT8 quantization")
41
- 32 + import torch
42
- 33 + from transformers import AutoModelForCausalLM, AutoTokenizer
43
- 34 + import torch.quantization
44
- 35 +
45
- 36 + local_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
46
- 37 + model = AutoModelForCausalLM.from_pretrained(
47
- 38 + MODEL_NAME,
48
- 39 + torch_dtype=torch.float32,
49
- 40 + device_map="cpu",
50
- 41 + )
51
- 42 + # Apply dynamic INT8 quantization for faster CPU inference
52
- 43 + local_model = torch.quantization.quantize_dynamic(
53
- 44 + model, {torch.nn.Linear}, dtype=torch.qint8
54
- 45 + )
55
- 46 + local_model.eval()
56
- 47 + print("Local INT8 model ready.")
57
- 48 +
58
- 49 yield
59
- 50
60
- 51
61
- ...
62
- 68
63
- 69
64
- 70 def parse_response(raw: str, texts: list[str]) -> list[dict]:
65
- 52 - # Strip thinking tags if present
66
- 71 raw = re.sub(r"<think>.*?</think>", "", raw, flags=re.DOTALL).strip()
67
- 54 - # Extract JSON array
68
- 72 match = re.search(r"\[.*\]", raw, re.DOTALL)
69
- 73 if match:
70
- 74 try:
71
- ...
72
- 77 return parsed
73
- 78 except json.JSONDecodeError:
74
- 79 pass
75
- 63 - # Fallback: neutral for all
76
- 80 return [{"text": t, "sentiment": "neutral", "score": 0.5} for t in texts]
77
- 81
78
- 82
79
- 67 -@app.post("/predict")
80
- 68 -async def predict(body: PredictRequest, request: Request):
81
- 69 - if API_KEY:
82
- 70 - key = request.headers.get("X-API-Key")
83
- 71 - if key != API_KEY:
84
- 72 - raise HTTPException(status_code=401, detail="Invalid API key")
85
- 83 +def run_hf_api(texts: list[str]) -> str:
86
- 84 + messages = [
87
- 85 + {"role": "system", "content": SYSTEM_PROMPT},
88
- 86 + {"role": "user", "content": build_prompt(texts)},
89
- 87 + ]
90
- 88 + response = inference_client.chat_completion(
91
- 89 + messages=messages,
92
- 90 + max_tokens=512,
93
- 91 + temperature=0.1,
94
- 92 + )
95
- 93 + return response.choices[0].message.content
96
- 94
97
- 74 - texts = body.texts
98
- 75 - if not texts:
99
- 76 - raise HTTPException(status_code=400, detail="texts must not be empty")
100
- 77 - if len(texts) > 20:
101
- 78 - raise HTTPException(status_code=400, detail="Maximum 20 texts per request")
102
- 95
103
- 80 - if model is None or tokenizer is None:
104
- 81 - raise HTTPException(status_code=503, detail="Model not loaded yet")
105
- 82 -
106
- 96 +def run_local(texts: list[str]) -> str:
107
- 97 + import torch
108
- 98 messages = [
109
- 99 {"role": "system", "content": SYSTEM_PROMPT},
110
- 100 {"role": "user", "content": build_prompt(texts)},
111
- 101 ]
112
- 87 -
113
- 88 - text_input = tokenizer.apply_chat_template(
114
- 102 + text_input = local_tokenizer.apply_chat_template(
115
- 103 messages,
116
- 104 tokenize=False,
117
- 105 add_generation_prompt=True,
118
- 92 - enable_thinking=False, # Disable thinking for faster response
119
- 106 + enable_thinking=False,
120
- 107 )
121
- 94 - inputs = tokenizer(text_input, return_tensors="pt")
122
- 95 -
123
- 108 + inputs = local_tokenizer(text_input, return_tensors="pt")
124
- 109 with torch.no_grad():
125
- 97 - outputs = model.generate(
126
- 110 + outputs = local_model.generate(
127
- 111 **inputs,
128
- 112 max_new_tokens=512,
129
- 113 do_sample=False,
130
- 101 - pad_token_id=tokenizer.eos_token_id,
131
- 114 + pad_token_id=local_tokenizer.eos_token_id,
132
- 115 )
133
- 103 -
134
- 116 generated = outputs[0][inputs["input_ids"].shape[1]:]
135
- 105 - raw = tokenizer.decode(generated, skip_special_tokens=True)
136
- 117 + return local_tokenizer.decode(generated, skip_special_tokens=True)
137
- 118
138
- 119 +
139
- 120 +@app.post("/predict")
140
- 121 +async def predict(body: PredictRequest, request: Request):
141
- 122 + if API_KEY:
142
- 123 + key = request.headers.get("X-API-Key")
143
- 124 + if key != API_KEY:
144
- 125 + raise HTTPException(status_code=401, detail="Invalid API key")
145
- 126 +
146
- 127 + texts = body.texts
147
- 128 + if not texts:
148
- 129 + raise HTTPException(status_code=400, detail="texts must not be empty")
149
- 130 + if len(texts) > 20:
150
- 131 + raise HTTPException(status_code=400, detail="Maximum 20 texts per request")
151
- 132 +
152
- 133 + if inference_client is None and local_model is None:
153
- 134 + raise HTTPException(status_code=503, detail="Model not loaded yet")
154
- 135 +
155
- 136 + raw = run_hf_api(texts) if inference_client else run_local(texts)
156
- 137 results = parse_response(raw, texts)
157
- 138
158
- 109 - # Normalize output format
159
- 139 normalized = []
160
- 140 for r in results:
161
- 141 sentiment = str(r.get("sentiment", "neutral")).lower()
162
- ...
163
- 152
164
- 153 @app.get("/health")
165
- 154 def health():
166
- 126 - return {"status": "ok", "model_loaded": model is not None}
167
- 155 + mode = "hf_api" if inference_client else "local_int8" if local_model else "not_loaded"
168
- 156 + return {"status": "ok", "mode": mode}
 
1
+ import json
2
+ import os
3
+ import re
4
+ from contextlib import asynccontextmanager
5
+
6
+ from fastapi import FastAPI, HTTPException, Request
7
+ from pydantic import BaseModel
8
+
9
# Model served by this app; used for both the remote API and local load paths.
MODEL_NAME = "Qwen/Qwen3.5-0.8B"
# Optional shared secret; when set, /predict requires a matching X-API-Key header.
API_KEY = os.getenv("API_KEY")
# Optional Hugging Face token; when set, inference is delegated to the HF Inference API.
HF_TOKEN = os.getenv("HF_TOKEN")

# Will hold either InferenceClient or local model+tokenizer
# (populated by lifespan() at startup; exactly one backend becomes active).
inference_client = None
local_model = None
local_tokenizer = None
17
+
18
+
19
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup hook: initialize exactly one inference backend before serving.

    With HF_TOKEN set, a hosted Inference API client is created; otherwise the
    model is loaded locally on CPU and dynamically quantized to INT8.
    """
    global inference_client, local_model, local_tokenizer

    if HF_TOKEN:
        # Remote path: delegate generation to the hosted Inference API.
        print("HF_TOKEN found — using HF Inference API")
        from huggingface_hub import InferenceClient

        inference_client = InferenceClient(model=MODEL_NAME, token=HF_TOKEN)
        print("Inference client ready.")
    else:
        # Local path: CPU inference, sped up with dynamic INT8 quantization.
        print("No HF_TOKEN — loading model locally with INT8 quantization")
        import torch
        import torch.quantization
        from transformers import AutoModelForCausalLM, AutoTokenizer

        local_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        base_model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32,
            device_map="cpu",
        )
        # Rewrite Linear layers to INT8 kernels for faster CPU matmuls.
        quantized = torch.quantization.quantize_dynamic(
            base_model, {torch.nn.Linear}, dtype=torch.qint8
        )
        quantized.eval()
        local_model = quantized
        print("Local INT8 model ready.")

    yield
50
+
51
+
52
# FastAPI application; backend setup/teardown is handled by the lifespan hook.
app = FastAPI(title="StockPro Sentiment", lifespan=lifespan)
53
+
54
+
55
class PredictRequest(BaseModel):
    """Request body for /predict: a batch of headline strings (handler enforces 1-20 items)."""

    texts: list[str]
57
+
58
+
59
# Instructs the model to emit a bare JSON array; parse_response() depends on this shape.
SYSTEM_PROMPT = """You are a financial sentiment analyzer for Indonesian stock market news.
Analyze each headline and return ONLY a JSON array with no extra text.
Each item must have: "text" (original), "sentiment" ("positive", "negative", or "neutral"), "score" (0.0-1.0 confidence).
Respond only with the JSON array, no markdown, no explanation."""
63
+
64
+
65
def build_prompt(texts: list[str]) -> str:
    """Format *texts* as a 1-indexed numbered list inside the analysis instruction."""
    numbered = [f"{idx}. {headline}" for idx, headline in enumerate(texts, start=1)]
    return "Analyze sentiment for these Indonesian stock headlines:\n" + "\n".join(numbered)
68
+
69
+
70
def parse_response(raw: str, texts: list[str]) -> list[dict]:
    """Extract the JSON array from the model's raw output.

    Strips any <think>...</think> block first. If no array can be found,
    it fails to parse, or its length differs from *texts*, every headline
    falls back to a neutral entry with score 0.5.
    """
    fallback = [{"text": t, "sentiment": "neutral", "score": 0.5} for t in texts]
    cleaned = re.sub(r"<think>.*?</think>", "", raw, flags=re.DOTALL).strip()
    found = re.search(r"\[.*\]", cleaned, re.DOTALL)
    if found is None:
        return fallback
    try:
        candidate = json.loads(found.group())
    except json.JSONDecodeError:
        return fallback
    if isinstance(candidate, list) and len(candidate) == len(texts):
        return candidate
    return fallback
81
+
82
+
83
def run_hf_api(texts: list[str]) -> str:
    """Run the sentiment prompt through the HF Inference API; return the raw model text."""
    chat = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": build_prompt(texts)},
    ]
    completion = inference_client.chat_completion(
        messages=chat,
        max_tokens=512,
        temperature=0.1,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
94
+
95
+
96
def run_local(texts: list[str]) -> str:
    """Run the sentiment prompt through the local quantized model; return decoded text."""
    import torch

    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": build_prompt(texts)},
    ]
    prompt = local_tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,  # skip <think> output to reduce latency
    )
    encoded = local_tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        generated_ids = local_model.generate(
            **encoded,
            max_new_tokens=512,
            do_sample=False,
            pad_token_id=local_tokenizer.eos_token_id,
        )
    # Drop the prompt tokens; decode only the newly generated continuation.
    prompt_len = encoded["input_ids"].shape[1]
    completion_ids = generated_ids[0][prompt_len:]
    return local_tokenizer.decode(completion_ids, skip_special_tokens=True)
118
+
119
+
120
@app.post("/predict")
async def predict(body: PredictRequest, request: Request):
    """Classify sentiment for up to 20 headlines.

    Auth: when API_KEY is configured, the X-API-Key header must match (401 otherwise).
    Validation: 400 on empty batch or more than 20 texts; 503 when no backend is loaded.
    Returns {"results": [{"text", "sentiment", "score"}, ...]} with sentiment in
    {"positive", "negative", "neutral"} and score clamped to [0.0, 1.0].
    """
    if API_KEY:
        key = request.headers.get("X-API-Key")
        if key != API_KEY:
            raise HTTPException(status_code=401, detail="Invalid API key")

    texts = body.texts
    if not texts:
        raise HTTPException(status_code=400, detail="texts must not be empty")
    if len(texts) > 20:
        raise HTTPException(status_code=400, detail="Maximum 20 texts per request")

    if inference_client is None and local_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    raw = run_hf_api(texts) if inference_client else run_local(texts)
    # parse_response guarantees a list matching len(texts), but NOT that each
    # item is a dict with sane fields — the model may emit e.g. bare numbers.
    results = parse_response(raw, texts)

    normalized = []
    for original_text, r in zip(texts, results):
        if not isinstance(r, dict):
            # Malformed item from the model: degrade to neutral instead of a 500.
            r = {}
        sentiment = str(r.get("sentiment", "neutral")).lower()
        if sentiment not in ("positive", "negative", "neutral"):
            sentiment = "neutral"
        try:
            score = float(r.get("score", 0.5))
        except (TypeError, ValueError):
            # Non-numeric score (e.g. "high"): fall back to neutral confidence.
            score = 0.5
        score = min(max(score, 0.0), 1.0)  # clamp to documented 0.0-1.0 range
        normalized.append({
            # Echo the input headline when the model omitted/garbled "text".
            "text": r.get("text", original_text),
            "sentiment": sentiment,
            "score": round(score, 4),
        })

    return {"results": normalized}
151
+
152
+
153
@app.get("/health")
def health():
    """Liveness probe reporting which inference backend is active."""
    if inference_client:
        mode = "hf_api"
    elif local_model:
        mode = "local_int8"
    else:
        mode = "not_loaded"
    return {"status": "ok", "mode": mode}