Arafath-ng committed on
Commit
358adf7
·
verified ·
1 Parent(s): eaa1860

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +31 -51
main.py CHANGED
@@ -1,59 +1,39 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from huggingface_hub import InferenceClient
4
- from fastapi.responses import StreamingResponse
5
- import uvicorn
6
 
7
  app = FastAPI()
8
 
9
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
10
-
11
- class Item(BaseModel):
12
- prompt: str
13
- history: list
14
- system_prompt: str
15
- temperature: float = 0.0
16
- max_new_tokens: int = 1048
17
- top_p: float = 0.15
18
- repetition_penalty: float = 1.0
19
-
20
- def format_prompt(message, history):
21
- print("````")
22
- print(message)
23
- print("++++")
24
- print(history)
25
- print("````")
26
- prompt = "<s>"
27
- for user_prompt, bot_response in history:
28
- prompt += f"[INST] {user_prompt} [/INST]"
29
- prompt += f" {bot_response}</s> "
30
- prompt += f"[INST] {message} [/INST]"
31
- return prompt
32
-
33
- async def generate_stream(item: Item):
34
- temperature = float(item.temperature)
35
- if temperature < 1e-2:
36
- temperature = 1e-2
37
- top_p = float(item.top_p)
38
-
39
- generate_kwargs = dict(
40
- temperature=temperature,
41
- max_new_tokens=item.max_new_tokens,
42
- top_p=top_p,
43
- repetition_penalty=item.repetition_penalty,
44
- do_sample=True,
45
- seed=42,
46
  )
47
 
48
- formatted_prompt = format_prompt(f"{item.system_prompt} [/INST] Ok..! </s> [INST] {item.prompt}", item.history)
49
- print(formatted_prompt)
50
- print("=======")
51
- print(item.history)
52
- stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
53
-
54
- for response in stream:
55
- yield response.token.text # Stream each token as it's received
56
 
57
- @app.post("/generate/")
58
- async def generate_text(item: Item):
59
- return StreamingResponse(generate_stream(item), media_type="text/plain")
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
+ from transformers import pipeline
 
 
4
 
5
  app = FastAPI()
6
 
7
+ # Load model once
8
+ classifier = pipeline(
9
+ "zero-shot-classification",
10
+ model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"
11
+ )
12
+
13
+ # Your classes
14
+ CANDIDATE_LABELS = [
15
+ "Garbage issue",
16
+ "Streetlight not working",
17
+ "Road damage",
18
+ "Water supply issue",
19
+ "Noise pollution",
20
+ "Flooding",
21
+ "Corruption",
22
+ "Other"
23
+ ]
24
+
25
+ class Query(BaseModel):
26
+ text: str
27
+
28
+ @app.post("/predict")
29
+ def predict(query: Query):
30
+ result = classifier(
31
+ query.text,
32
+ candidate_labels=CANDIDATE_LABELS,
33
+ multi_label=False # <-- single best class
 
 
 
 
 
 
 
 
 
 
34
  )
35
 
36
+ # Top-1 predicted label
37
+ predicted_label = result["labels"][0]
 
 
 
 
 
 
38
 
39
+ return {"label": predicted_label}