Amish Kushwaha committed on
Commit
27c4f25
·
1 Parent(s): ae29859

Fix bitsandbytes issue - attempt 2

Browse files
Files changed (1) hide show
  1. app.py +23 -13
app.py CHANGED
@@ -1,29 +1,39 @@
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from transformers import pipeline
 
 
 
 
 
 
 
 
 
 
4
 
5
- # Load your Hugging Face model
6
- # model = pipeline("text-generation", model="devops-bda/Abap")
7
- model = pipeline(
8
- "text-generation",
9
- model="devops-bda/Abap",
10
- model_kwargs={"load_in_4bit": False} # Disable 4-bit quantization
11
  )
 
 
 
 
12
 
13
- # Initialize FastAPI app
14
  app = FastAPI()
15
 
16
- # Define input format
17
  class InputData(BaseModel):
18
  input_text: str
19
 
20
- # Health check endpoint
21
  @app.get("/health")
22
  async def health_check():
23
  return {"status": "ok", "message": "Model is ready"}
24
 
25
- # Define prediction endpoint
26
  @app.post("/predict")
27
  async def predict(data: InputData):
28
- result = model(data.input_text, max_length=500)
29
- return {"output": result}
 
import os
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
    AutoConfig
)

# Load the checkpoint's configuration and strip any quantization config
# embedded in it, so a bitsandbytes 4-bit setup baked into the repo does
# not get applied at load time (this is the "bitsandbytes issue" being fixed).
config = AutoConfig.from_pretrained("devops-bda/Abap")
if hasattr(config, "quantization_config"):
    config.quantization_config = None

# Load the model and tokenizer without 4-bit quantization.
# NOTE(review): `load_in_4bit` as a from_pretrained kwarg is deprecated in
# recent transformers releases (quantization is configured via
# `quantization_config=` instead) — confirm the pinned transformers version
# still accepts it.
model = AutoModelForCausalLM.from_pretrained(
    "devops-bda/Abap",
    config=config,
    load_in_4bit=False  # explicitly disable 4-bit quantization
)
tokenizer = AutoTokenizer.from_pretrained("devops-bda/Abap")

# Wrap the loaded model + tokenizer in a text-generation pipeline;
# this is the callable used by the /predict endpoint below.
text_gen_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)

# FastAPI application instance; route handlers are registered against it below.
app = FastAPI()
 
# Request body schema for /predict: a single free-form text prompt.
class InputData(BaseModel):
    input_text: str  # prompt fed to the text-generation pipeline
31
 
 
@app.get("/health")
async def health_check():
    """Liveness probe: report that the service is up and the model loaded."""
    payload = {"status": "ok", "message": "Model is ready"}
    return payload
35
 
 
36
@app.post("/predict")
async def predict(data: InputData):
    """Generate text from the submitted prompt.

    The transformers pipeline call is CPU/GPU-bound and blocking; running it
    directly in an ``async def`` handler would stall the event loop for the
    entire generation, freezing every other request (including /health).
    Dispatch it to the threadpool instead.
    """
    # Local import: fastapi is already a dependency of this file; this avoids
    # touching the top-level import block.
    from fastapi.concurrency import run_in_threadpool

    # NOTE(review): max_length counts prompt + generated tokens; consider
    # max_new_tokens if long prompts should not eat into the output budget.
    output = await run_in_threadpool(
        text_gen_pipeline, data.input_text, max_length=500
    )
    return {"output": output}